Detailed changes
@@ -223,6 +223,7 @@ name = "anthropic"
version = "0.1.0"
dependencies = [
"anyhow",
+ "chrono",
"futures 0.3.30",
"http_client",
"isahc",
@@ -232,6 +233,7 @@ dependencies = [
"strum",
"thiserror",
"tokio",
+ "util",
]
[[package]]
@@ -5057,9 +5059,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "heed"
-version = "0.20.4"
+version = "0.20.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "620033c8c8edfd2f53e6f99a30565eb56a33b42c468e3ad80e21d85fb93bafb0"
+checksum = "7d4f449bab7320c56003d37732a917e18798e2f1709d80263face2b4f9436ddb"
dependencies = [
"bitflags 2.6.0",
"byteorder",
@@ -6314,9 +6316,9 @@ dependencies = [
[[package]]
name = "lmdb-master-sys"
-version = "0.2.3"
+version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1de7e761853c15ca72821d9f928d7bb123ef4c05377c4e7ab69fa1c742f91d24"
+checksum = "472c3760e2a8d0f61f322fb36788021bb36d573c502b50fa3e2bcaac3ec326c9"
dependencies = [
"cc",
"doxygen-rs",
@@ -7590,6 +7592,29 @@ version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
+[[package]]
+name = "performance"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "gpui",
+ "log",
+ "schemars",
+ "serde",
+ "settings",
+ "util",
+ "workspace",
+]
+
+[[package]]
+name = "perplexity"
+version = "0.1.0"
+dependencies = [
+ "serde",
+ "zed_extension_api 0.1.0",
+]
+
[[package]]
name = "pest"
version = "2.7.11"
@@ -13875,6 +13900,7 @@ dependencies = [
"outline_panel",
"parking_lot",
"paths",
+ "performance",
"profiling",
"project",
"project_panel",
@@ -70,6 +70,7 @@ members = [
"crates/outline",
"crates/outline_panel",
"crates/paths",
+ "crates/performance",
"crates/picker",
"crates/prettier",
"crates/project",
@@ -145,6 +146,7 @@ members = [
"extensions/lua",
"extensions/ocaml",
"extensions/php",
+ "extensions/perplexity",
"extensions/prisma",
"extensions/purescript",
"extensions/ruff",
@@ -241,6 +243,7 @@ open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
outline_panel = { path = "crates/outline_panel" }
paths = { path = "crates/paths" }
+performance = { path = "crates/performance" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
plugin_macros = { path = "crates/plugin_macros" }
@@ -0,0 +1,11 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<g clip-path="url(#clip0_1896_18)">
+<path d="M11.094 3.09999H8.952L12.858 12.9H15L11.094 3.09999Z" fill="#1F1F1E"/>
+<path d="M4.906 3.09999L1 12.9H3.184L3.98284 10.842H8.06915L8.868 12.9H11.052L7.146 3.09999H4.906ZM4.68928 9.02199L6.026 5.57799L7.3627 9.02199H4.68928Z" fill="#1F1F1E"/>
+</g>
+<defs>
+<clipPath id="clip0_1896_18">
+<rect width="14" height="9.8" fill="white" transform="translate(1 3.09999)"/>
+</clipPath>
+</defs>
+</svg>
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-ellipsis-vertical"><circle cx="12" cy="12" r="1"/><circle cx="12" cy="5" r="1"/><circle cx="12" cy="19" r="1"/></svg>
@@ -1,10 +0,0 @@
-<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
- <path d="M3 13L7.01562 8.98438" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M8.6875 7.3125L9.5 6.5" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M7 5V3" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M12 5V3" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M12 10V8" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M6 4L8 4" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M11 4L13 4" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
- <path d="M11 9L13 9" stroke="black" stroke-width="1.5" stroke-linecap="round"/>
-</svg>
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-search-code"><path d="m13 13.5 2-2.5-2-2.5"/><path d="m21 21-4.3-4.3"/><path d="M9 8.5 7 11l2 2.5"/><circle cx="11" cy="11" r="8"/></svg>
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-slash"><path d="M22 2 2 22"/></svg>
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" class="lucide lucide-square-slash"><rect width="18" height="18" x="3" y="3" rx="2"/><line x1="9" x2="15" y1="15" y2="9"/></svg>
@@ -1,426 +1,61 @@
-You are an expert developer assistant working in an AI-enabled text editor.
-Your task is to rewrite a specific section of the provided document based on a user-provided prompt.
-
-<guidelines>
-1. Scope: Modify only content within <rewrite_this> tags. Do not alter anything outside these boundaries.
-2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is.
-3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure.
-4. Tag exclusion: Never include <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags in the output.
-5. Indentation: Maintain the original indentation level of the file in rewritten sections.
-6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions.
-7. Insertions: Replace <insert_here></insert_here> tags with appropriate content as specified by the prompt.
-8. Code integrity: Respect existing code structure and functionality when making changes.
-9. Consistency: Maintain a uniform style and tone throughout the rewritten text.
-</guidelines>
-
-<examples>
-<example>
-<input>
-<document>
-use std::cell::Cell;
-use std::collections::HashMap;
-use std::cmp;
-
-<rewrite_this>
-<insert_here></insert_here>
-</rewrite_this>
-pub struct LruCache<K, V> {
- /// The maximum number of items the cache can hold.
- capacity: usize,
- /// The map storing the cached items.
- items: HashMap<K, V>,
-}
-
-// The rest of the implementation...
-</document>
-<prompt>
-doc this
-</prompt>
-</input>
-
-<incorrect_output failure="Over-generation. The text starting with `pub struct AabbTree<T> {` is *after* the rewrite_this tag">
-/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
-///
-/// This structure is used for efficient spatial queries and collision detection.
-/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
-///
-/// # Type Parameters
-///
-/// * `T`: The type of data associated with each node in the tree.
-pub struct AabbTree<T> {
- root: Option<usize>,
-</incorrect_output>
-<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
-/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
-///
-/// This structure is used for efficient spatial queries and collision detection.
-/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
-///
-/// # Type Parameters
-///
-/// * `T`: The type of data associated with each node in the tree.
-</corrected_output>
-</example>
-
-<example>
-<input>
-<document>
-import math
-
-def calculate_circle_area(radius):
- """Calculate the area of a circle given its radius."""
- return math.pi * radius ** 2
-
-<rewrite_this>
-<insert_here></insert_here>
-</rewrite_this>
-class Circle:
- def __init__(self, radius):
- self.radius = radius
-
- def area(self):
- return math.pi * self.radius ** 2
-
- def circumference(self):
- return 2 * math.pi * self.radius
-
-# Usage example
-circle = Circle(5)
-print(f"Area: {circle.area():.2f}")
-print(f"Circumference: {circle.circumference():.2f}")
-</document>
-<prompt>
-write docs
-</prompt>
-</input>
-
-<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
-"""
-Represents a circle with methods to calculate its area and circumference.
-
-This class provides a simple way to work with circles in a geometric context.
-It allows for the creation of Circle objects with a specified radius and
-offers methods to compute the circle's area and circumference.
-
-Attributes:
- radius (float): The radius of the circle.
-
-Methods:
- area(): Calculates and returns the area of the circle.
- circumference(): Calculates and returns the circumference of the circle.
-"""
-class Circle:
-</incorrect_output>
-<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
-"""
-Represents a circle with methods to calculate its area and circumference.
-
-This class provides a simple way to work with circles in a geometric context.
-It allows for the creation of Circle objects with a specified radius and
-offers methods to compute the circle's area and circumference.
-
-Attributes:
- radius (float): The radius of the circle.
+{{#if language_name}}
+Here's a file of {{language_name}} that I'm going to ask you to make an edit to.
+{{else}}
+Here's a file of text that I'm going to ask you to make an edit to.
+{{/if}}
-Methods:
- area(): Calculates and returns the area of the circle.
- circumference(): Calculates and returns the circumference of the circle.
-"""
-</corrected_output>
-</example>
+{{#if is_insert}}
+The point you'll need to insert at is marked with <insert_here></insert_here>.
+{{else}}
+The section you'll need to rewrite is marked with <rewrite_this></rewrite_this> tags.
+{{/if}}
-<example>
-<input>
<document>
-class BankAccount {
- private balance: number;
-
- constructor(initialBalance: number) {
- this.balance = initialBalance;
- }
-
- <rewrite_this>
- <insert_here></insert_here>
- </rewrite_this>
- deposit(amount: number): void {
- if (amount > 0) {
- this.balance += amount;
- }
- }
-
- withdraw(amount: number): boolean {
- if (amount > 0 && this.balance >= amount) {
- this.balance -= amount;
- return true;
- }
- return false;
- }
-
- getBalance(): number {
- return this.balance;
- }
-}
-
-// Usage
-const account = new BankAccount(1000);
-account.deposit(500);
-console.log(account.getBalance()); // 1500
-account.withdraw(200);
-console.log(account.getBalance()); // 1300
+{{{document_content}}}
</document>
-<prompt>
-//
-</prompt>
-</input>
-<incorrect_output failure="Over-generation. The text starting with `deposit(amount: number): void {` is *after* the rewrite_this tag">
- /**
- * Deposits the specified amount into the bank account.
- *
- * @param amount The amount to deposit. Must be a positive number.
- * @throws Error if the amount is not positive.
- */
- deposit(amount: number): void {
- if (amount > 0) {
- this.balance += amount;
- } else {
- throw new Error("Deposit amount must be positive");
- }
- }
-</incorrect_output>
-<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
- /**
- * Deposits the specified amount into the bank account.
- *
- * @param amount The amount to deposit. Must be a positive number.
- * @throws Error if the amount is not positive.
- */
-</corrected_output>
-</example>
+{{#if is_truncated}}
+The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
+{{/if}}
-<example>
-<input>
-<document>
-use std::collections::VecDeque;
+{{#if is_insert}}
+You can't replace {{content_type}}, your answer will be inserted in place of the `<insert_here></insert_here>` tags. Don't include the insert_here tags in your output.
-pub struct BinaryTree<T> {
- root: Option<Node<T>>,
-}
+Generate {{content_type}} based on the following prompt:
-<rewrite_this>
-<insert_here></insert_here>
-</rewrite_this>
-struct Node<T> {
- value: T,
- left: Option<Box<Node<T>>>,
- right: Option<Box<Node<T>>>,
-}
-</document>
<prompt>
-derive clone
+{{{user_prompt}}}
</prompt>
-</input>
-
-<incorrect_output failure="Over-generation below the rewrite_this tags. Extra space between derive annotation and struct definition.">
-#[derive(Clone)]
-
-struct Node<T> {
- value: T,
- left: Option<Box<Node<T>>>,
- right: Option<Box<Node<T>>>,
-}
-</incorrect_output>
-<incorrect_output failure="Over-generation above the rewrite_this tags">
-pub struct BinaryTree<T> {
- root: Option<Node<T>>,
-}
+Match the indentation in the original file in the inserted {{content_type}}, don't include any indentation on blank lines.
-#[derive(Clone)]
-</incorrect_output>
-
-<incorrect_output failure="Over-generation below the rewrite_this tags">
-#[derive(Clone)]
-struct Node<T> {
- value: T,
- left: Option<Box<Node<T>>>,
- right: Option<Box<Node<T>>>,
-}
-
-impl<T> Node<T> {
- fn new(value: T) -> Self {
- Node {
- value,
- left: None,
- right: None,
- }
- }
-}
-</incorrect_output>
-<corrected_output improvement="Only includes the new content within the rewrite_this tags">
-#[derive(Clone)]
-</corrected_output>
-</example>
-
-<example>
-<input>
-<document>
-import math
-
-def calculate_circle_area(radius):
- """Calculate the area of a circle given its radius."""
- return math.pi * radius ** 2
-
-<rewrite_this>
-<insert_here></insert_here>
-</rewrite_this>
-class Circle:
- def __init__(self, radius):
- self.radius = radius
-
- def area(self):
- return math.pi * self.radius ** 2
+Immediately start with the following format with no remarks:
- def circumference(self):
- return 2 * math.pi * self.radius
+```
+\{{INSERTED_CODE}}
+```
+{{else}}
+Edit the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
-# Usage example
-circle = Circle(5)
-print(f"Area: {circle.area():.2f}")
-print(f"Circumference: {circle.circumference():.2f}")
-</document>
<prompt>
-add dataclass decorator
+{{{user_prompt}}}
</prompt>
-</input>
-
-<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
-@dataclass
-class Circle:
- radius: float
-
- def __init__(self, radius):
- self.radius = radius
- def area(self):
- return math.pi * self.radius ** 2
-</incorrect_output>
-<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
-@dataclass
-</corrected_output>
-</example>
-
-<example>
-<input>
-<document>
-interface ShoppingCart {
- items: string[];
- total: number;
-}
+{{#if rewrite_section}}
+And here's the section to rewrite based on that prompt again for reference:
<rewrite_this>
-<insert_here></insert_here>class ShoppingCartManager {
+{{{rewrite_section}}}
</rewrite_this>
- private cart: ShoppingCart;
-
- constructor() {
- this.cart = { items: [], total: 0 };
- }
-
- addItem(item: string, price: number): void {
- this.cart.items.push(item);
- this.cart.total += price;
- }
-
- getTotal(): number {
- return this.cart.total;
- }
-}
-
-// Usage
-const manager = new ShoppingCartManager();
-manager.addItem("Book", 15.99);
-console.log(manager.getTotal()); // 15.99
-</document>
-<prompt>
-add readonly modifier
-</prompt>
-</input>
-
-<incorrect_output failure="Over-generation. The line starting with ` items: string[];` is *after* the rewrite_this tag">
-readonly interface ShoppingCart {
- items: string[];
- total: number;
-}
-
-class ShoppingCartManager {
- private readonly cart: ShoppingCart;
-
- constructor() {
- this.cart = { items: [], total: 0 };
- }
-</incorrect_output>
-<corrected_output improvement="Only includes the new content within the rewrite_this tags and integrates cleanly into surrounding code">
-readonly interface ShoppingCart {
-</corrected_output>
-</example>
-
-</examples>
-
-With these examples in mind, edit the following file:
-
-<document language="{{ language_name }}">
-{{{ document_content }}}
-</document>
-
-{{#if is_truncated}}
-The provided document has been truncated (potentially mid-line) for brevity.
-{{/if}}
-
-<instructions>
-{{#if has_insertion}}
-Insert text anywhere you see marked with <insert_here></insert_here> tags. It's CRITICAL that you DO NOT include <insert_here> tags in your output.
-{{/if}}
-{{#if has_replacement}}
-Edit text that you see surrounded with <edit_here>...</edit_here> tags. It's CRITICAL that you DO NOT include <edit_here> tags in your output.
{{/if}}
-Make no changes to the rewritten content outside these tags.
-<snippet language="{{ language_name }}" annotated="true">
-{{{ rewrite_section_prefix }}}
-<rewrite_this>
-{{{ rewrite_section_with_edits }}}
-</rewrite_this>
-{{{ rewrite_section_suffix }}}
-</snippet>
+Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved.
-Rewrite the lines enclosed within the <rewrite_this></rewrite_this> tags in accordance with the provided instructions and the prompt below.
-
-<prompt>
-{{{ user_prompt }}}
-</prompt>
-
-Do not include <insert_here> or <edit_here> annotations in your output. Here is a clean copy of the snippet without annotations for your reference.
-
-<snippet>
-{{{ rewrite_section_prefix }}}
-{{{ rewrite_section }}}
-{{{ rewrite_section_suffix }}}
-</snippet>
-</instructions>
-
-<guidelines_reminder>
-1. Focus on necessary changes: Modify only what's required to fulfill the prompt.
-2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow.
-3. Exclude annotation tags: Do not output <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags.
-4. Maintain indentation: Begin at the original file's indentation level.
-5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed.
-6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output.
-7. Respect content boundaries: Preserve code integrity.
-</guidelines_reminder>
+Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make, always write out the whole section with no unnecessary elisions.
Immediately start with the following format with no remarks:
```
\{{REWRITTEN_CODE}}
```
+{{/if}}
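The rewritten template above is plain Handlebars, so its inputs are easiest to see by rendering it directly. The sketch below is a minimal, hypothetical example using the `handlebars` crate with illustrative values for the variables the template references (`language_name`, `is_insert`, `is_truncated`, `content_type`, `document_content`, `user_prompt`, `rewrite_section`); the actual PromptBuilder wiring in the codebase may differ.

```rust
use handlebars::Handlebars;
use serde_json::json;

// Minimal sketch, assuming the `handlebars` crate; `content_prompt` is the
// template text shown in the diff above, and all data values are illustrative.
fn render_content_prompt(content_prompt: &str) -> anyhow::Result<String> {
    let mut registry = Handlebars::new();
    registry.register_template_string("content_prompt", content_prompt)?;
    let rendered = registry.render(
        "content_prompt",
        &json!({
            "language_name": "Rust",
            "is_insert": false,
            "is_truncated": false,
            "content_type": "code",
            "document_content": "fn main() {}\n",
            "user_prompt": "add a doc comment",
            "rewrite_section": "fn main() {}",
        }),
    )?;
    Ok(rendered)
}
```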
@@ -17,6 +17,7 @@ path = "src/anthropic.rs"
[dependencies]
anyhow.workspace = true
+chrono.workspace = true
futures.workspace = true
http_client.workspace = true
isahc.workspace = true
@@ -25,6 +26,7 @@ serde.workspace = true
serde_json.workspace = true
strum.workspace = true
thiserror.workspace = true
+util.workspace = true
[dev-dependencies]
tokio.workspace = true
@@ -1,14 +1,17 @@
mod supported_countries;
use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
+use isahc::http::{HeaderMap, HeaderValue};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
+use util::ResultExt as _;
pub use supported_countries::*;
@@ -38,6 +41,8 @@ pub enum Model {
Custom {
name: String,
max_tokens: usize,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+ display_name: Option<String>,
/// Override this model with a different Anthropic model for tool calls.
tool_override: Option<String>,
/// Indicates whether this custom model supports caching.
@@ -77,7 +82,9 @@ impl Model {
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
- Self::Custom { name, .. } => name,
+ Self::Custom {
+ name, display_name, ..
+ } => display_name.as_ref().unwrap_or(name),
}
}
@@ -191,6 +198,66 @@ pub async fn stream_completion(
request: Request,
low_speed_timeout: Option<Duration>,
) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
+ stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
+ .await
+ .map(|output| output.0)
+}
+
+/// https://docs.anthropic.com/en/api/rate-limits#response-headers
+#[derive(Debug)]
+pub struct RateLimitInfo {
+ pub requests_limit: usize,
+ pub requests_remaining: usize,
+ pub requests_reset: DateTime<Utc>,
+ pub tokens_limit: usize,
+ pub tokens_remaining: usize,
+ pub tokens_reset: DateTime<Utc>,
+}
+
+impl RateLimitInfo {
+ fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
+ let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
+ let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
+ let tokens_remaining =
+ get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
+ let requests_remaining =
+ get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
+ let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
+ let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
+ let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
+ let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
+
+ Ok(Self {
+ requests_limit,
+ tokens_limit,
+ requests_remaining,
+ tokens_remaining,
+ requests_reset,
+ tokens_reset,
+ })
+ }
+}
+
+fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
+ Ok(headers
+ .get(key)
+ .ok_or_else(|| anyhow!("missing header `{key}`"))?
+ .to_str()?)
+}
+
+pub async fn stream_completion_with_rate_limit_info(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+ low_speed_timeout: Option<Duration>,
+) -> Result<
+ (
+ BoxStream<'static, Result<Event, AnthropicError>>,
+ Option<RateLimitInfo>,
+ ),
+ AnthropicError,
+> {
let request = StreamingRequest {
base: request,
stream: true,
@@ -220,8 +287,9 @@ pub async fn stream_completion(
.await
.context("failed to send request to Anthropic")?;
if response.status().is_success() {
+ let rate_limits = RateLimitInfo::from_headers(response.headers());
let reader = BufReader::new(response.into_body());
- Ok(reader
+ let stream = reader
.lines()
.filter_map(|line| async move {
match line {
@@ -235,7 +303,8 @@ pub async fn stream_completion(
Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
}
})
- .boxed())
+ .boxed();
+ Ok((stream, rate_limits.log_err()))
} else {
let mut body = Vec::new();
response
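Taken together, the changes above turn `stream_completion` into a thin wrapper around `stream_completion_with_rate_limit_info`, which also returns the `RateLimitInfo` parsed from the `anthropic-ratelimit-*` response headers. A hedged caller sketch, assuming the `anthropic` crate items from this diff are in scope; the helper name, timeout value, and use of the `log` crate are illustrative:

```rust
use std::time::Duration;

use anthropic::{stream_completion_with_rate_limit_info, AnthropicError, Request};
use futures::StreamExt as _;
use http_client::HttpClient;

// Illustrative caller: stream events and log the rate-limit info, mirroring
// the tuple returned by the new function in the diff.
async fn stream_and_log_limits(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<(), AnthropicError> {
    let (mut events, rate_limits) = stream_completion_with_rate_limit_info(
        client,
        api_url,
        api_key,
        request,
        Some(Duration::from_secs(30)),
    )
    .await?;

    // `rate_limits` is None when the headers were missing or malformed
    // (the parse error is logged via `log_err()` in the diff above).
    if let Some(limits) = rate_limits {
        log::info!(
            "requests {}/{} (reset {}), tokens {}/{} (reset {})",
            limits.requests_remaining,
            limits.requests_limit,
            limits.requests_reset,
            limits.tokens_remaining,
            limits.tokens_limit,
            limits.tokens_reset,
        );
    }

    while let Some(event) = events.next().await {
        let _event = event?; // handle each streamed Event here
    }
    Ok(())
}
```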
@@ -9,6 +9,7 @@ mod model_selector;
mod prompt_library;
mod prompts;
mod slash_command;
+pub(crate) mod slash_command_picker;
pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
@@ -33,7 +34,7 @@ use language_model::{
};
pub(crate) use model_selector::*;
pub use prompts::PromptBuilder;
-use prompts::PromptOverrideContext;
+use prompts::PromptLoadingParams;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
@@ -59,7 +60,6 @@ actions!(
InsertIntoEditor,
ToggleFocus,
InsertActivePrompt,
- ShowConfiguration,
DeployHistory,
DeployPromptLibrary,
ConfirmCommand,
@@ -184,7 +184,7 @@ impl Assistant {
pub fn init(
fs: Arc<dyn Fs>,
client: Arc<Client>,
- dev_mode: bool,
+ stdout_is_a_pty: bool,
cx: &mut AppContext,
) -> Arc<PromptBuilder> {
cx.set_global(Assistant::default());
@@ -223,9 +223,11 @@ pub fn init(
assistant_panel::init(cx);
context_servers::init(cx);
- let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
- dev_mode,
+ let prompt_builder = prompts::PromptBuilder::new(Some(PromptLoadingParams {
fs: fs.clone(),
+ repo_path: stdout_is_a_pty
+ .then(|| std::env::current_dir().log_err())
+ .flatten(),
cx,
}))
.log_err()
@@ -9,6 +9,7 @@ use crate::{
file_command::codeblock_fence_for_path,
SlashCommandCompletionProvider, SlashCommandRegistry,
},
+ slash_command_picker,
terminal_inline_assistant::TerminalInlineAssistant,
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant,
@@ -16,7 +17,7 @@ use crate::{
QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus,
ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
};
-use crate::{ContextStoreEvent, ModelPickerDelegate, ShowConfiguration};
+use crate::{ContextStoreEvent, ModelPickerDelegate};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::{proto, Client, Status};
@@ -35,10 +36,10 @@ use fs::Fs;
use gpui::{
canvas, div, img, percentage, point, pulsating_between, size, Action, Animation, AnimationExt,
AnyElement, AnyView, AppContext, AsyncWindowContext, ClipboardEntry, ClipboardItem,
- Context as _, CursorStyle, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle,
- FocusableView, FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels,
- ReadGlobal, Render, RenderImage, SharedString, Size, StatefulInteractiveElement, Styled,
- Subscription, Task, Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
+ Context as _, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle, FocusableView,
+ FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels, ReadGlobal, Render,
+ RenderImage, SharedString, Size, StatefulInteractiveElement, Styled, Subscription, Task,
+ Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
};
use indexed_docs::IndexedDocsStore;
use language::{
@@ -68,8 +69,8 @@ use ui::TintColor;
use ui::{
prelude::*,
utils::{format_distance_from_now, DateTimeType},
- Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, KeyBinding, ListItem,
- ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip,
+ Avatar, AvatarShape, ButtonLike, ContextMenu, Disclosure, ElevationIndex, IconButtonShape,
+ KeyBinding, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, Tooltip,
};
use util::ResultExt;
use workspace::{
@@ -77,7 +78,8 @@ use workspace::{
item::{self, FollowableItem, Item, ItemHandle},
pane::{self, SaveIntent},
searchable::{SearchEvent, SearchableItem},
- Pane, Save, ToggleZoom, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
+ Pane, Save, ShowConfiguration, ToggleZoom, ToolbarItemEvent, ToolbarItemLocation,
+ ToolbarItemView, Workspace,
};
use workspace::{searchable::SearchableItemHandle, NewFile};
@@ -119,7 +121,7 @@ struct InlineAssistTabBarButton;
impl Render for InlineAssistTabBarButton {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- IconButton::new("terminal_inline_assistant", IconName::MagicWand)
+ IconButton::new("terminal_inline_assistant", IconName::ZedAssistant)
.icon_size(IconSize::Small)
.on_click(cx.listener(|_, _, cx| {
cx.dispatch_action(InlineAssist::default().boxed_clone());
@@ -541,19 +543,18 @@ impl AssistantPanel {
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
-
- pane::Event::RemoveItem { idx } => {
- if self
+ pane::Event::RemovedItem { .. } => {
+ let has_configuration_view = self
.pane
.read(cx)
- .item_for_index(*idx)
- .map_or(false, |item| item.downcast::<ConfigurationView>().is_some())
- {
+ .items_of_type::<ConfigurationView>()
+ .next()
+ .is_some();
+
+ if !has_configuration_view {
self.configuration_subscription = None;
}
- false
- }
- pane::Event::RemovedItem { .. } => {
+
cx.emit(AssistantPanelEvent::ContextEdited);
true
}
@@ -704,7 +705,9 @@ impl AssistantPanel {
self.authenticate_provider_task = Some((
provider.id(),
cx.spawn(|this, mut cx| async move {
- let _ = load_credentials.await;
+ if let Some(future) = load_credentials {
+ let _ = future.await;
+ }
this.update(&mut cx, |this, _cx| {
this.authenticate_provider_task = None;
})
@@ -735,6 +738,7 @@ impl AssistantPanel {
};
let initial_prompt = action.prompt.clone();
+
if assistant_panel.update(cx, |assistant, cx| assistant.is_authenticated(cx)) {
match inline_assist_target {
InlineAssistTarget::Editor(active_editor, include_context) => {
@@ -763,9 +767,27 @@ impl AssistantPanel {
} else {
let assistant_panel = assistant_panel.downgrade();
cx.spawn(|workspace, mut cx| async move {
- assistant_panel
- .update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
- .await?;
+ let Some(task) =
+ assistant_panel.update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
+ else {
+ let answer = cx
+ .prompt(
+ gpui::PromptLevel::Warning,
+ "No language model provider configured",
+ None,
+ &["Configure", "Cancel"],
+ )
+ .await
+ .ok();
+ if let Some(answer) = answer {
+ if answer == 0 {
+ cx.update(|cx| cx.dispatch_action(Box::new(ShowConfiguration)))
+ .ok();
+ }
+ }
+ return Ok(());
+ };
+ task.await?;
if assistant_panel.update(&mut cx, |panel, cx| panel.is_authenticated(cx))? {
cx.update(|cx| match inline_assist_target {
InlineAssistTarget::Editor(active_editor, include_context) => {
@@ -1173,13 +1195,10 @@ impl AssistantPanel {
.map_or(false, |provider| provider.is_authenticated(cx))
}
- fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Option<Task<Result<()>>> {
LanguageModelRegistry::read_global(cx)
.active_provider()
- .map_or(
- Task::ready(Err(anyhow!("no active language model provider"))),
- |provider| provider.authenticate(cx),
- )
+ .map_or(None, |provider| Some(provider.authenticate(cx)))
}
}
@@ -1336,6 +1355,7 @@ struct WorkflowStep {
footer_block_id: CustomBlockId,
resolved_step: Option<Result<WorkflowStepResolution, Arc<anyhow::Error>>>,
assist: Option<WorkflowAssist>,
+ auto_apply: bool,
}
impl WorkflowStep {
@@ -1372,13 +1392,16 @@ impl WorkflowStep {
}
}
Some(Err(error)) => WorkflowStepStatus::Error(error.clone()),
- None => WorkflowStepStatus::Resolving,
+ None => WorkflowStepStatus::Resolving {
+ auto_apply: self.auto_apply,
+ },
}
}
}
+#[derive(Clone)]
enum WorkflowStepStatus {
- Resolving,
+ Resolving { auto_apply: bool },
Error(Arc<anyhow::Error>),
Empty,
Idle,
@@ -1455,16 +1478,6 @@ impl WorkflowStepStatus {
.unwrap_or_default()
}
match self {
- WorkflowStepStatus::Resolving => Label::new("Resolving")
- .size(LabelSize::Small)
- .with_animation(
- ("resolving-suggestion-animation", id),
- Animation::new(Duration::from_secs(2))
- .repeat()
- .with_easing(pulsating_between(0.4, 0.8)),
- |label, delta| label.alpha(delta),
- )
- .into_any_element(),
WorkflowStepStatus::Error(error) => Self::render_workflow_step_error(
id,
editor.clone(),
@@ -1477,43 +1490,72 @@ impl WorkflowStepStatus {
step_range.clone(),
"Model was unable to locate the code to edit".to_string(),
),
- WorkflowStepStatus::Idle => Button::new(("transform", id), "Transform")
- .icon(IconName::SparkleAlt)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .label_size(LabelSize::Small)
- .style(ButtonStyle::Tinted(TintColor::Accent))
- .tooltip({
- let step_range = step_range.clone();
- let editor = editor.clone();
- move |cx| {
- cx.new_view(|cx| {
- let tooltip = Tooltip::new("Transform");
- if display_keybind_in_tooltip(&step_range, &editor, cx) {
- tooltip.key_binding(KeyBinding::for_action_in(
- &Assist,
- &focus_handle,
- cx,
- ))
- } else {
- tooltip
- }
- })
- .into()
- }
- })
- .on_click({
- let editor = editor.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- editor
- .update(cx, |this, cx| {
- this.apply_workflow_step(step_range.clone(), cx)
+ WorkflowStepStatus::Idle | WorkflowStepStatus::Resolving { .. } => {
+ let status = self.clone();
+ Button::new(("transform", id), "Transform")
+ .icon(IconName::SparkleAlt)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .label_size(LabelSize::Small)
+ .style(ButtonStyle::Tinted(TintColor::Accent))
+ .tooltip({
+ let step_range = step_range.clone();
+ let editor = editor.clone();
+ move |cx| {
+ cx.new_view(|cx| {
+ let tooltip = Tooltip::new("Transform");
+ if display_keybind_in_tooltip(&step_range, &editor, cx) {
+ tooltip.key_binding(KeyBinding::for_action_in(
+ &Assist,
+ &focus_handle,
+ cx,
+ ))
+ } else {
+ tooltip
+ }
})
- .ok();
- }
- })
- .into_any_element(),
+ .into()
+ }
+ })
+ .on_click({
+ let editor = editor.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ if let WorkflowStepStatus::Idle = &status {
+ editor
+ .update(cx, |this, cx| {
+ this.apply_workflow_step(step_range.clone(), cx)
+ })
+ .ok();
+ } else if let WorkflowStepStatus::Resolving { auto_apply: false } =
+ &status
+ {
+ editor
+ .update(cx, |this, _| {
+ if let Some(step) = this.workflow_steps.get_mut(&step_range)
+ {
+ step.auto_apply = true;
+ }
+ })
+ .ok();
+ }
+ }
+ })
+ .map(|this| {
+ if let WorkflowStepStatus::Resolving { auto_apply: true } = &self {
+ this.with_animation(
+ ("resolving-suggestion-animation", id),
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else {
+ this.into_any_element()
+ }
+ })
+ }
WorkflowStepStatus::Pending => h_flex()
.items_center()
.gap_2()
@@ -1699,6 +1741,8 @@ pub struct ContextEditor {
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
show_accept_terms: bool,
+ pub(crate) slash_menu_handle:
+ PopoverMenuHandle<Picker<slash_command_picker::SlashCommandDelegate>>,
}
const DEFAULT_TAB_TITLE: &str = "New Context";
@@ -1760,10 +1804,11 @@ impl ContextEditor {
assistant_panel,
error_message: None,
show_accept_terms: false,
+ slash_menu_handle: Default::default(),
};
this.update_message_headers(cx);
this.update_image_blocks(cx);
- this.insert_slash_command_output_sections(sections, cx);
+ this.insert_slash_command_output_sections(sections, false, cx);
this
}
@@ -1776,7 +1821,7 @@ impl ContextEditor {
let command = self.context.update(cx, |context, cx| {
let first_message_id = context.messages(cx).next().unwrap().id;
context.update_metadata(first_message_id, cx, |metadata| {
- metadata.role = Role::System;
+ metadata.role = Role::User;
});
context.reparse_slash_commands(cx);
context.pending_slash_commands()[0].clone()
@@ -1787,6 +1832,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ true,
self.workspace.clone(),
cx,
);
@@ -1834,7 +1880,7 @@ impl ContextEditor {
let range = step.range.clone();
match step.status(cx) {
- WorkflowStepStatus::Resolving | WorkflowStepStatus::Pending => true,
+ WorkflowStepStatus::Resolving { .. } | WorkflowStepStatus::Pending => true,
WorkflowStepStatus::Idle => {
self.apply_workflow_step(range, cx);
true
@@ -1988,7 +2034,7 @@ impl ContextEditor {
.collect()
}
- fn insert_command(&mut self, name: &str, cx: &mut ViewContext<Self>) {
+ pub fn insert_command(&mut self, name: &str, cx: &mut ViewContext<Self>) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
self.editor.update(cx, |editor, cx| {
editor.transact(cx, |editor, cx| {
@@ -2053,6 +2099,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
true,
+ false,
workspace.clone(),
cx,
);
@@ -2061,19 +2108,27 @@ impl ContextEditor {
}
}
+ #[allow(clippy::too_many_arguments)]
pub fn run_command(
&mut self,
command_range: Range<language::Anchor>,
name: &str,
arguments: &[String],
- insert_trailing_newline: bool,
+ ensure_trailing_newline: bool,
+ expand_result: bool,
workspace: WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) {
if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
let output = command.run(arguments, workspace, self.lsp_adapter_delegate.clone(), cx);
self.context.update(cx, |context, cx| {
- context.insert_command_output(command_range, output, insert_trailing_newline, cx)
+ context.insert_command_output(
+ command_range,
+ output,
+ ensure_trailing_newline,
+ expand_result,
+ cx,
+ )
});
}
}
@@ -2159,6 +2214,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ false,
workspace.clone(),
cx,
);
@@ -2255,8 +2311,13 @@ impl ContextEditor {
output_range,
sections,
run_commands_in_output,
+ expand_result,
} => {
- self.insert_slash_command_output_sections(sections.iter().cloned(), cx);
+ self.insert_slash_command_output_sections(
+ sections.iter().cloned(),
+ *expand_result,
+ cx,
+ );
if *run_commands_in_output {
let commands = self.context.update(cx, |context, cx| {
@@ -2272,6 +2333,7 @@ impl ContextEditor {
&command.name,
&command.arguments,
false,
+ false,
self.workspace.clone(),
cx,
);
@@ -2288,6 +2350,7 @@ impl ContextEditor {
fn insert_slash_command_output_sections(
&mut self,
sections: impl IntoIterator<Item = SlashCommandOutputSection<language::Anchor>>,
+ expand_result: bool,
cx: &mut ViewContext<Self>,
) {
self.editor.update(cx, |editor, cx| {
@@ -2342,6 +2405,9 @@ impl ContextEditor {
editor.insert_creases(creases, cx);
+ if expand_result {
+ buffer_rows_to_fold.clear();
+ }
for buffer_row in buffer_rows_to_fold.into_iter().rev() {
editor.fold_at(&FoldAt { buffer_row }, cx);
}
@@ -2517,19 +2583,35 @@ impl ContextEditor {
div().child(step_label)
};
- let step_label = step_label
+ let step_label_element = step_label.into_any_element();
+
+ let step_label = h_flex()
.id("step")
- .cursor(CursorStyle::PointingHand)
- .on_click({
- let this = weak_self.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- this.update(cx, |this, cx| {
- this.open_workflow_step(step_range.clone(), cx);
- })
- .ok();
- }
- });
+ .group("step-label")
+ .items_center()
+ .gap_1()
+ .child(step_label_element)
+ .child(
+ IconButton::new("edit-step", IconName::SearchCode)
+ .size(ButtonSize::Compact)
+ .icon_size(IconSize::Small)
+ .shape(IconButtonShape::Square)
+ .visible_on_hover("step-label")
+ .tooltip(|cx| Tooltip::text("Open Step View", cx))
+ .on_click({
+ let this = weak_self.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ this.update(cx, |this, cx| {
+ this.open_workflow_step(
+ step_range.clone(),
+ cx,
+ );
+ })
+ .ok();
+ }
+ }),
+ );
div()
.w_full()
@@ -2612,11 +2694,17 @@ impl ContextEditor {
footer_block_id: block_ids[1],
resolved_step,
assist: None,
+ auto_apply: false,
},
);
}
self.update_active_workflow_step(cx);
+ if let Some(step) = self.workflow_steps.get_mut(&step_range) {
+ if step.auto_apply && matches!(step.status(cx), WorkflowStepStatus::Idle) {
+ self.apply_workflow_step(step_range, cx);
+ }
+ }
}
fn open_workflow_step(
@@ -2654,14 +2742,25 @@ impl ContextEditor {
fn update_active_workflow_step(&mut self, cx: &mut ViewContext<Self>) {
let new_step = self.active_workflow_step_for_cursor(cx);
if new_step.as_ref() != self.active_workflow_step.as_ref() {
+ let mut old_editor = None;
+ let mut old_editor_was_open = None;
if let Some(old_step) = self.active_workflow_step.take() {
- self.hide_workflow_step(old_step.range, cx);
+ (old_editor, old_editor_was_open) =
+ self.hide_workflow_step(old_step.range, cx).unzip();
}
+ let mut new_editor = None;
if let Some(new_step) = new_step {
- self.show_workflow_step(new_step.range.clone(), cx);
+ new_editor = self.show_workflow_step(new_step.range.clone(), cx);
self.active_workflow_step = Some(new_step);
}
+
+ if new_editor != old_editor {
+ if let Some((old_editor, old_editor_was_open)) = old_editor.zip(old_editor_was_open)
+ {
+ self.close_workflow_editor(cx, old_editor, old_editor_was_open)
+ }
+ }
}
}
@@ -2669,15 +2768,15 @@ impl ContextEditor {
&mut self,
step_range: Range<language::Anchor>,
cx: &mut ViewContext<Self>,
- ) {
+ ) -> Option<(View<Editor>, bool)> {
let Some(step) = self.workflow_steps.get_mut(&step_range) else {
- return;
+ return None;
};
let Some(assist) = step.assist.as_ref() else {
- return;
+ return None;
};
let Some(editor) = assist.editor.upgrade() else {
- return;
+ return None;
};
if matches!(step.status(cx), WorkflowStepStatus::Idle) {
@@ -2687,32 +2786,42 @@ impl ContextEditor {
assistant.finish_assist(assist_id, true, cx)
}
});
-
- self.workspace
- .update(cx, |workspace, cx| {
- if let Some(pane) = workspace.pane_for(&editor) {
- pane.update(cx, |pane, cx| {
- let item_id = editor.entity_id();
- if !assist.editor_was_open && pane.is_active_preview_item(item_id) {
- pane.close_item_by_id(item_id, SaveIntent::Skip, cx)
- .detach_and_log_err(cx);
- }
- });
- }
- })
- .ok();
+ return Some((editor, assist.editor_was_open));
}
+
+ return None;
+ }
+
+ fn close_workflow_editor(
+ &mut self,
+ cx: &mut ViewContext<ContextEditor>,
+ editor: View<Editor>,
+ editor_was_open: bool,
+ ) {
+ self.workspace
+ .update(cx, |workspace, cx| {
+ if let Some(pane) = workspace.pane_for(&editor) {
+ pane.update(cx, |pane, cx| {
+ let item_id = editor.entity_id();
+ if !editor_was_open && pane.is_active_preview_item(item_id) {
+ pane.close_item_by_id(item_id, SaveIntent::Skip, cx)
+ .detach_and_log_err(cx);
+ }
+ });
+ }
+ })
+ .ok();
}
fn show_workflow_step(
&mut self,
step_range: Range<language::Anchor>,
cx: &mut ViewContext<Self>,
- ) {
+ ) -> Option<View<Editor>> {
let Some(step) = self.workflow_steps.get_mut(&step_range) else {
- return;
+ return None;
};
-
+ let mut editor_to_return = None;
let mut scroll_to_assist_id = None;
match step.status(cx) {
WorkflowStepStatus::Idle => {
@@ -2726,6 +2835,10 @@ impl ContextEditor {
&self.workspace,
cx,
);
+ editor_to_return = step
+ .assist
+ .as_ref()
+ .and_then(|assist| assist.editor.upgrade());
}
}
WorkflowStepStatus::Pending => {
@@ -2747,14 +2860,15 @@ impl ContextEditor {
}
if let Some(assist_id) = scroll_to_assist_id {
- if let Some(editor) = step
+ if let Some(assist_editor) = step
.assist
.as_ref()
.and_then(|assists| assists.editor.upgrade())
{
+ editor_to_return = Some(assist_editor.clone());
self.workspace
.update(cx, |workspace, cx| {
- workspace.activate_item(&editor, false, false, cx);
+ workspace.activate_item(&assist_editor, false, false, cx);
})
.ok();
InlineAssistant::update_global(cx, |assistant, cx| {
@@ -2762,6 +2876,8 @@ impl ContextEditor {
});
}
}
+
+ return editor_to_return;
}
fn open_assists_for_step(
@@ -2805,7 +2921,6 @@ impl ContextEditor {
)
})
.log_err()?;
-
let (&excerpt_id, _, _) = editor
.read(cx)
.buffer()
@@ -3484,7 +3599,8 @@ impl ContextEditor {
};
Some(
h_flex()
- .p_3()
+ .px_3()
+ .py_2()
.border_b_1()
.border_color(cx.theme().colors().border_variant)
.bg(cx.theme().colors().editor_background)
@@ -3520,13 +3636,17 @@ impl ContextEditor {
fn render_send_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let focus_handle = self.focus_handle(cx).clone();
+ let mut should_pulsate = false;
let button_text = match self.active_workflow_step() {
Some(step) => match step.status(cx) {
- WorkflowStepStatus::Resolving => "Resolving Step...",
WorkflowStepStatus::Empty | WorkflowStepStatus::Error(_) => "Retry Step Resolution",
+ WorkflowStepStatus::Resolving { auto_apply } => {
+ should_pulsate = auto_apply;
+ "Transform"
+ }
WorkflowStepStatus::Idle => "Transform",
- WorkflowStepStatus::Pending => "Transforming...",
- WorkflowStepStatus::Done => "Accept Transformation",
+ WorkflowStepStatus::Pending => "Applying...",
+ WorkflowStepStatus::Done => "Accept",
WorkflowStepStatus::Confirmed => "Send",
},
None => "Send",
@@ -3556,12 +3676,12 @@ impl ContextEditor {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let has_configuration_error = configuration_error(cx).is_some();
let needs_to_accept_terms = self.show_accept_terms
&& provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));
- let has_active_error = self.error_message.is_some();
- let disabled = needs_to_accept_terms || has_active_error;
+ let disabled = has_configuration_error || needs_to_accept_terms;
ButtonLike::new("send_button")
.disabled(disabled)
@@ -3570,11 +3690,24 @@ impl ContextEditor {
button.tooltip(move |_| tooltip.clone())
})
.layer(ElevationIndex::ModalSurface)
+ .child(Label::new(button_text).map(|this| {
+ if should_pulsate {
+ this.with_animation(
+ "resolving-suggestion-send-button-animation",
+ Animation::new(Duration::from_secs(2))
+ .repeat()
+ .with_easing(pulsating_between(0.4, 0.8)),
+ |label, delta| label.alpha(delta),
+ )
+ .into_any_element()
+ } else {
+ this.into_any_element()
+ }
+ }))
.children(
KeyBinding::for_action_in(&Assist, &focus_handle, cx)
.map(|binding| binding.into_any_element()),
)
- .child(Label::new(button_text))
.on_click(move |_event, cx| {
focus_handle.dispatch_action(&Assist, cx);
})
@@ -3604,7 +3737,13 @@ impl Render for ContextEditor {
} else {
None
};
-
+ let focus_handle = self
+ .workspace
+ .update(cx, |workspace, cx| {
+ Some(workspace.active_item_as::<Editor>(cx)?.focus_handle(cx))
+ })
+ .ok()
+ .flatten();
v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
@@ -3627,8 +3766,8 @@ impl Render for ContextEditor {
this.child(
div()
.absolute()
- .right_4()
- .bottom_10()
+ .right_3()
+ .bottom_12()
.max_w_96()
.py_2()
.px_3()
@@ -3642,8 +3781,8 @@ impl Render for ContextEditor {
this.child(
div()
.absolute()
- .right_4()
- .bottom_10()
+ .right_3()
+ .bottom_12()
.max_w_96()
.py_2()
.px_3()
@@ -3654,12 +3793,12 @@ impl Render for ContextEditor {
.gap_0p5()
.child(
h_flex()
- .gap_1()
+ .gap_1p5()
.items_center()
.child(Icon::new(IconName::XCircle).color(Color::Error))
.child(
Label::new("Error interacting with language model")
- .weight(FontWeight::SEMIBOLD),
+ .weight(FontWeight::MEDIUM),
),
)
.child(
@@ -3681,14 +3820,45 @@ impl Render for ContextEditor {
)
})
.child(
- h_flex().flex_none().relative().child(
+ h_flex().w_full().relative().child(
h_flex()
+ .p_2()
.w_full()
- .absolute()
- .right_4()
- .bottom_2()
- .justify_end()
- .child(self.render_send_button(cx)),
+ .border_t_1()
+ .border_color(cx.theme().colors().border_variant)
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ h_flex()
+ .gap_2()
+ .child(render_inject_context_menu(cx.view().downgrade(), cx))
+ .child(
+ IconButton::new("quote-button", IconName::Quote)
+ .icon_size(IconSize::Small)
+ .on_click(|_, cx| {
+ cx.dispatch_action(QuoteSelection.boxed_clone());
+ })
+ .tooltip(move |cx| {
+ cx.new_view(|cx| {
+ Tooltip::new("Insert Selection").key_binding(
+ focus_handle.as_ref().and_then(|handle| {
+ KeyBinding::for_action_in(
+ &QuoteSelection,
+ &handle,
+ cx,
+ )
+ }),
+ )
+ })
+ .into()
+ }),
+ ),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .justify_end()
+ .child(div().child(self.render_send_button(cx))),
+ ),
),
)
}
@@ -3937,6 +4107,37 @@ pub struct ContextEditorToolbarItem {
model_selector_menu_handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>,
}
+fn active_editor_focus_handle(
+ workspace: &WeakView<Workspace>,
+ cx: &WindowContext<'_>,
+) -> Option<FocusHandle> {
+ workspace.upgrade().and_then(|workspace| {
+ Some(
+ workspace
+ .read(cx)
+ .active_item_as::<Editor>(cx)?
+ .focus_handle(cx),
+ )
+ })
+}
+
+fn render_inject_context_menu(
+ active_context_editor: WeakView<ContextEditor>,
+ cx: &mut WindowContext<'_>,
+) -> impl IntoElement {
+ let commands = SlashCommandRegistry::global(cx);
+
+ slash_command_picker::SlashCommandSelector::new(
+ commands.clone(),
+ active_context_editor,
+ IconButton::new("trigger", IconName::SlashSquare)
+ .icon_size(IconSize::Small)
+ .tooltip(|cx| {
+ Tooltip::with_meta("Insert Context", None, "Type / to insert via keyboard", cx)
+ }),
+ )
+}
+
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
@@ -3952,70 +4153,6 @@ impl ContextEditorToolbarItem {
}
}
- fn render_inject_context_menu(&self, cx: &mut ViewContext<Self>) -> impl Element {
- let commands = SlashCommandRegistry::global(cx);
- let active_editor_focus_handle = self.workspace.upgrade().and_then(|workspace| {
- Some(
- workspace
- .read(cx)
- .active_item_as::<Editor>(cx)?
- .focus_handle(cx),
- )
- });
- let active_context_editor = self.active_context_editor.clone();
-
- PopoverMenu::new("inject-context-menu")
- .trigger(IconButton::new("trigger", IconName::Quote).tooltip(|cx| {
- Tooltip::with_meta("Insert Context", None, "Type / to insert via keyboard", cx)
- }))
- .menu(move |cx| {
- let active_context_editor = active_context_editor.clone()?;
- ContextMenu::build(cx, |mut menu, _cx| {
- for command_name in commands.featured_command_names() {
- if let Some(command) = commands.command(&command_name) {
- let menu_text = SharedString::from(Arc::from(command.menu_text()));
- menu = menu.custom_entry(
- {
- let command_name = command_name.clone();
- move |_cx| {
- h_flex()
- .gap_4()
- .w_full()
- .justify_between()
- .child(Label::new(menu_text.clone()))
- .child(
- Label::new(format!("/{command_name}"))
- .color(Color::Muted),
- )
- .into_any()
- }
- },
- {
- let active_context_editor = active_context_editor.clone();
- move |cx| {
- active_context_editor
- .update(cx, |context_editor, cx| {
- context_editor.insert_command(&command_name, cx)
- })
- .ok();
- }
- },
- )
- }
- }
-
- if let Some(active_editor_focus_handle) = active_editor_focus_handle.clone() {
- menu = menu
- .context(active_editor_focus_handle)
- .action("Quote Selection", Box::new(QuoteSelection));
- }
-
- menu
- })
- .into()
- })
- }
-
fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
let context = &self
.active_context_editor
@@ -4062,24 +4199,16 @@ impl ContextEditorToolbarItem {
impl Render for ContextEditorToolbarItem {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let left_side = h_flex()
+ .pl_1()
.gap_2()
.flex_1()
.min_w(rems(DEFAULT_TAB_TITLE.len() as f32))
.when(self.active_context_editor.is_some(), |left_side| {
- left_side
- .child(
- IconButton::new("regenerate-context", IconName::ArrowCircle)
- .visible_on_hover("toolbar")
- .tooltip(|cx| Tooltip::text("Regenerate Summary", cx))
- .on_click(cx.listener(move |_, _, cx| {
- cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
- })),
- )
- .child(self.model_summary_editor.clone())
+ left_side.child(self.model_summary_editor.clone())
});
let active_provider = LanguageModelRegistry::read_global(cx).active_provider();
let active_model = LanguageModelRegistry::read_global(cx).active_model();
-
+ let weak_self = cx.view().downgrade();
let right_side = h_flex()
.gap_2()
.child(
@@ -4100,7 +4229,7 @@ impl Render for ContextEditorToolbarItem {
(Some(provider), Some(model)) => h_flex()
.gap_1()
.child(
- Icon::new(provider.icon())
+ Icon::new(model.icon().unwrap_or_else(|| provider.icon()))
.color(Color::Muted)
.size(IconSize::XSmall),
)
@@ -4129,7 +4258,70 @@ impl Render for ContextEditorToolbarItem {
.with_handle(self.model_selector_menu_handle.clone()),
)
.children(self.render_remaining_tokens(cx))
- .child(self.render_inject_context_menu(cx));
+ .child(
+ PopoverMenu::new("context-editor-popover")
+ .trigger(
+ IconButton::new("context-editor-trigger", IconName::EllipsisVertical)
+ .icon_size(IconSize::Small)
+ .tooltip(|cx| Tooltip::text("Open Context Options", cx)),
+ )
+ .menu({
+ let weak_self = weak_self.clone();
+ move |cx| {
+ let weak_self = weak_self.clone();
+ Some(ContextMenu::build(cx, move |menu, cx| {
+ let context = weak_self
+ .update(cx, |this, cx| {
+ active_editor_focus_handle(&this.workspace, cx)
+ })
+ .ok()
+ .flatten();
+ menu.when_some(context, |menu, context| menu.context(context))
+ .entry("Regenerate Context Title", None, {
+ let weak_self = weak_self.clone();
+ move |cx| {
+ weak_self
+ .update(cx, |_, cx| {
+ cx.emit(ContextEditorToolbarItemEvent::RegenerateSummary)
+ })
+ .ok();
+ }
+ })
+ .custom_entry(
+ |_| {
+ h_flex()
+ .w_full()
+ .justify_between()
+ .gap_2()
+ .child(Label::new("Insert Context"))
+ .child(Label::new("/ command").color(Color::Muted))
+ .into_any()
+ },
+ {
+ let weak_self = weak_self.clone();
+ move |cx| {
+ weak_self
+ .update(cx, |this, cx| {
+ if let Some(editor) =
+ &this.active_context_editor
+ {
+ editor
+ .update(cx, |this, cx| {
+ this.slash_menu_handle
+ .toggle(cx);
+ })
+ .ok();
+ }
+ })
+ .ok();
+ }
+ },
+ )
+ .action("Insert Selection", QuoteSelection.boxed_clone())
+ }))
+ }
+ }),
+ );
h_flex()
.size_full()
@@ -295,6 +295,7 @@ pub enum ContextEvent {
output_range: Range<language::Anchor>,
sections: Vec<SlashCommandOutputSection<language::Anchor>>,
run_commands_in_output: bool,
+ expand_result: bool,
},
Operation(ContextOperation),
}
@@ -774,6 +775,7 @@ impl Context {
cx.emit(ContextEvent::SlashCommandFinished {
output_range,
sections,
+ expand_result: false,
run_commands_in_output: false,
});
}
@@ -1395,7 +1397,8 @@ impl Context {
&mut self,
command_range: Range<language::Anchor>,
output: Task<Result<SlashCommandOutput>>,
- insert_trailing_newline: bool,
+ ensure_trailing_newline: bool,
+ expand_result: bool,
cx: &mut ModelContext<Self>,
) {
self.reparse_slash_commands(cx);
@@ -1406,8 +1409,27 @@ impl Context {
let output = output.await;
this.update(&mut cx, |this, cx| match output {
Ok(mut output) => {
- if insert_trailing_newline {
- output.text.push('\n');
+ // Ensure section ranges are valid.
+ for section in &mut output.sections {
+ section.range.start = section.range.start.min(output.text.len());
+ section.range.end = section.range.end.min(output.text.len());
+ while !output.text.is_char_boundary(section.range.start) {
+ section.range.start -= 1;
+ }
+ while !output.text.is_char_boundary(section.range.end) {
+ section.range.end += 1;
+ }
+ }
+
+ // Ensure there is a newline after the last section.
+ if ensure_trailing_newline {
+ let has_newline_after_last_section =
+ output.sections.last().map_or(false, |last_section| {
+ output.text[last_section.range.end..].ends_with('\n')
+ });
+ if !has_newline_after_last_section {
+ output.text.push('\n');
+ }
}
let version = this.version.clone();
@@ -1450,6 +1472,7 @@ impl Context {
output_range,
sections,
run_commands_in_output: output.run_commands_in_text,
+ expand_result,
},
)
});
@@ -473,7 +473,7 @@ async fn test_slash_commands(cx: &mut TestAppContext) {
}
#[gpui::test]
-async fn test_edit_step_parsing(cx: &mut TestAppContext) {
+async fn test_workflow_step_parsing(cx: &mut TestAppContext) {
cx.update(prompt_library::init);
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
@@ -891,6 +891,7 @@ async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: Std
run_commands_in_text: false,
})),
true,
+ false,
cx,
);
});
@@ -28,7 +28,7 @@ use gpui::{
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
UpdateGlobal, View, ViewContext, WeakView, WindowContext,
};
-use language::{Buffer, IndentKind, Point, TransactionId};
+use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
@@ -38,6 +38,7 @@ use rope::Rope;
use settings::Settings;
use smol::future::FutureExt;
use std::{
+ cmp,
future::{self, Future},
mem,
ops::{Range, RangeInclusive},
@@ -46,7 +47,6 @@ use std::{
task::{self, Poll},
time::{Duration, Instant},
};
-use text::OffsetRangeExt as _;
use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt};
@@ -140,81 +140,66 @@ impl InlineAssistant {
cx: &mut WindowContext,
) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
- struct CodegenRange {
- transform_range: Range<Point>,
- selection_ranges: Vec<Range<Point>>,
- focus_assist: bool,
- }
- let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range();
- let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
-
- let selection_ranges = snapshot
- .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
- .map(|range| range.to_point(&snapshot))
- .collect::<Vec<Range<Point>>>();
-
- for selection_range in selection_ranges {
- let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
- let mut transform_range = selection_range.start..selection_range.end;
-
- // Expand the transform range to start/end of lines.
- // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
- transform_range.start.column = 0;
- if transform_range.end.column == 0 && transform_range.end > transform_range.start {
- transform_range.end.row -= 1;
- }
- transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
- let selection_range =
- selection_range.start..selection_range.end.min(transform_range.end);
-
- // If we intersect the previous transform range,
- if let Some(CodegenRange {
- transform_range: prev_transform_range,
- selection_ranges,
- focus_assist,
- }) = codegen_ranges.last_mut()
- {
- if transform_range.start <= prev_transform_range.end {
- prev_transform_range.end = transform_range.end;
- selection_ranges.push(selection_range);
- *focus_assist |= selection_is_newest;
+ let mut selections = Vec::<Selection<Point>>::new();
+ let mut newest_selection = None;
+ for mut selection in editor.read(cx).selections.all::<Point>(cx) {
+ if selection.end > selection.start {
+ selection.start.column = 0;
+            // If the selection ends at column 0, exclude that trailing line from the range.
+ if selection.end.column == 0 {
+ selection.end.row -= 1;
+ }
+ selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
+ }
+
+ if let Some(prev_selection) = selections.last_mut() {
+ if selection.start <= prev_selection.end {
+ prev_selection.end = selection.end;
continue;
}
}
- codegen_ranges.push(CodegenRange {
- transform_range,
- selection_ranges: vec![selection_range],
- focus_assist: selection_is_newest,
- })
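+        // Track the selection with the highest id (the most recently created one)
+        // so we can decide below which of the resulting assists to focus.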
+ let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
+ if selection.id > latest_selection.id {
+ *latest_selection = selection.clone();
+ }
+ selections.push(selection);
+ }
+ let newest_selection = newest_selection.unwrap();
+
+ let mut codegen_ranges = Vec::new();
+ for (excerpt_id, buffer, buffer_range) in
+ snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
+ snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
+ }))
+ {
+ let start = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_before(buffer_range.start),
+ };
+ let end = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_after(buffer_range.end),
+ };
+ codegen_ranges.push(start..end);
}
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
+
let mut assists = Vec::new();
let mut assist_to_focus = None;
-
- for CodegenRange {
- transform_range,
- selection_ranges,
- focus_assist,
- } in codegen_ranges
- {
- let transform_range = snapshot.anchor_before(transform_range.start)
- ..snapshot.anchor_after(transform_range.end);
- let selection_ranges = selection_ranges
- .iter()
- .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
- .collect::<Vec<_>>();
-
+ for range in codegen_ranges {
+ let assist_id = self.next_assist_id.post_inc();
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
- transform_range.clone(),
- selection_ranges,
+ range.clone(),
None,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -222,7 +207,6 @@ impl InlineAssistant {
)
});
- let assist_id = self.next_assist_id.post_inc();
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
PromptEditor::new(
@@ -239,16 +223,23 @@ impl InlineAssistant {
)
});
- if focus_assist {
- assist_to_focus = Some(assist_id);
+ if assist_to_focus.is_none() {
+ let focus_assist = if newest_selection.reversed {
+ range.start.to_point(&snapshot) == newest_selection.start
+ } else {
+ range.end.to_point(&snapshot) == newest_selection.end
+ };
+ if focus_assist {
+ assist_to_focus = Some(assist_id);
+ }
}
let [prompt_block_id, end_block_id] =
- self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
+ self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((
assist_id,
- transform_range,
+ range,
prompt_editor,
prompt_block_id,
end_block_id,
@@ -315,7 +306,6 @@ impl InlineAssistant {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
- vec![range.clone()],
initial_transaction_id,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -925,7 +915,12 @@ impl InlineAssistant {
assist
.codegen
.update(cx, |codegen, cx| {
- codegen.start(user_prompt, assistant_panel_context, cx)
+ codegen.start(
+ assist.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
})
.log_err();
@@ -2120,9 +2115,12 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
- self.codegen
- .read(cx)
- .count_tokens(user_prompt, assistant_panel_context, cx)
+ self.codegen.read(cx).count_tokens(
+ self.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
}
}
@@ -2143,8 +2141,6 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
- transform_range: Range<Anchor>,
- selected_ranges: Vec<Range<Anchor>>,
edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>,
@@ -2154,7 +2150,7 @@ pub struct Codegen {
diff: Diff,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
- prompt_builder: Arc<PromptBuilder>,
+ builder: Arc<PromptBuilder>,
}
enum CodegenStatus {
@@ -2181,8 +2177,7 @@ impl EventEmitter<CodegenEvent> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
- transform_range: Range<Anchor>,
- selected_ranges: Vec<Range<Anchor>>,
+ range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
@@ -2192,7 +2187,7 @@ impl Codegen {
let (old_buffer, _, _) = buffer
.read(cx)
- .range_to_buffer_ranges(transform_range.clone(), cx)
+ .range_to_buffer_ranges(range.clone(), cx)
.pop()
.unwrap();
let old_buffer = cx.new_model(|cx| {
@@ -2223,9 +2218,7 @@ impl Codegen {
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_transaction_id,
- prompt_builder: builder,
- transform_range,
- selected_ranges,
+ builder,
}
}
@@ -2250,12 +2243,14 @@ impl Codegen {
pub fn count_tokens(
&self,
+ edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
- let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
+ let request =
+ self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
match request {
Ok(request) => {
let total_count = model.count_tokens(request.clone(), cx);
@@ -2280,6 +2275,7 @@ impl Codegen {
pub fn start(
&mut self,
+ edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
@@ -2294,20 +2290,24 @@ impl Codegen {
});
}
- self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
+ self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id();
- let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(stream::empty().boxed()) }.boxed_local()
- } else {
- let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
+ let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
+ .trim()
+ .to_lowercase()
+ == "delete"
+ {
+ async { Ok(stream::empty().boxed()) }.boxed_local()
+ } else {
+ let request =
+ self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
- let chunks =
- cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
- async move { Ok(chunks.await?.boxed()) }.boxed_local()
- };
- self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
+ let chunks =
+ cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+ async move { Ok(chunks.await?.boxed()) }.boxed_local()
+ };
+ self.handle_stream(telemetry_id, edit_range, chunks, cx);
Ok(())
}
@@ -2315,10 +2315,11 @@ impl Codegen {
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
+ edit_range: Range<Anchor>,
cx: &AppContext,
) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx);
- let language = buffer.language_at(self.transform_range.start);
+ let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@@ -2343,9 +2344,9 @@ impl Codegen {
};
let language_name = language_name.as_deref();
- let start = buffer.point_to_buffer_offset(self.transform_range.start);
- let end = buffer.point_to_buffer_offset(self.transform_range.end);
- let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
+ let start = buffer.point_to_buffer_offset(edit_range.start);
+ let end = buffer.point_to_buffer_offset(edit_range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
@@ -2357,39 +2358,9 @@ impl Codegen {
return Err(anyhow::anyhow!("invalid transformation range"));
};
- let mut transform_context_range = transform_range.to_point(&transform_buffer);
- transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
- transform_context_range.start.column = 0;
- transform_context_range.end =
- (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
- transform_context_range.end.column =
- transform_buffer.line_len(transform_context_range.end.row);
- let transform_context_range = transform_context_range.to_offset(&transform_buffer);
-
- let selected_ranges = self
- .selected_ranges
- .iter()
- .filter_map(|selected_range| {
- let start = buffer
- .point_to_buffer_offset(selected_range.start)
- .map(|(_, offset)| offset)?;
- let end = buffer
- .point_to_buffer_offset(selected_range.end)
- .map(|(_, offset)| offset)?;
- Some(start..end)
- })
- .collect::<Vec<_>>();
-
let prompt = self
- .prompt_builder
- .generate_content_prompt(
- user_prompt,
- language_name,
- transform_buffer,
- transform_range,
- selected_ranges,
- transform_context_range,
- )
+ .builder
+ .generate_content_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut messages = Vec::new();
@@ -2462,19 +2433,84 @@ impl Codegen {
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
+ let mut new_text = String::new();
+ let mut base_indent = None;
+ let mut line_indent = None;
+ let mut first_line = true;
+
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
- let char_ops = diff.push_new(&chunk);
- line_diff.push_char_operations(&char_ops, &selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
+
+ let mut lines = chunk.split('\n').peekable();
+ while let Some(line) = lines.next() {
+ new_text.push_str(line);
+ if line_indent.is_none() {
+ if let Some(non_whitespace_ch_ix) =
+ new_text.find(|ch: char| !ch.is_whitespace())
+ {
+ line_indent = Some(non_whitespace_ch_ix);
+ base_indent = base_indent.or(line_indent);
+
+ let line_indent = line_indent.unwrap();
+ let base_indent = base_indent.unwrap();
+ let indent_delta =
+ line_indent as i32 - base_indent as i32;
+ let mut corrected_indent_len = cmp::max(
+ 0,
+ suggested_line_indent.len as i32 + indent_delta,
+ )
+ as usize;
+ if first_line {
+ corrected_indent_len = corrected_indent_len
+ .saturating_sub(
+ selection_start.column as usize,
+ );
+ }
+
+ let indent_char = suggested_line_indent.char();
+ let mut indent_buffer = [0; 4];
+ let indent_str =
+ indent_char.encode_utf8(&mut indent_buffer);
+ new_text.replace_range(
+ ..line_indent,
+ &indent_str.repeat(corrected_indent_len),
+ );
+ }
+ }
+
+ if line_indent.is_some() {
+ let char_ops = diff.push_new(&new_text);
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ new_text.clear();
+ }
+
+ if lines.peek().is_some() {
+ let char_ops = diff.push_new("\n");
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ if line_indent.is_none() {
+                                // This line contained only whitespace, so the `if` above never
+                                // flushed the buffer; drop the buffered indentation so it isn't
+                                // prepended to the next line.
+ new_text.clear();
+ }
+ line_indent = None;
+ first_line = false;
+ }
+ }
}
- let char_ops = diff.finish();
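+                        // Flush any text still buffered (a final line without a trailing
+                        // newline) before finalizing the diff.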
+ let mut char_ops = diff.push_new(&new_text);
+ char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
@@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use futures::stream::{self};
+ use gpui::{Context, TestAppContext};
+ use indoc::indoc;
+ use language::{
+ language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
+ Point,
+ };
+ use language_model::LanguageModelRegistry;
+ use rand::prelude::*;
use serde::Serialize;
+ use settings::SettingsStore;
+ use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
+ #[gpui::test(iterations = 10)]
+ async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_model::LanguageModelRegistry::test);
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ let x = 0;
+ for _ in 0..10 {
+ x += 1;
+ }
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range,
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let mut new_text = concat!(
+ " let mut x = 0;\n",
+ " while x < 10 {\n",
+ " x += 1;\n",
+ " }",
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_past_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ le
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "t mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_before_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = concat!(
+ "fn main() {\n",
+ " \n",
+ "}\n" //
+ );
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "let mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ func main() {
+ \tx := 0
+ \tfor i := 0; i < 10; i++ {
+ \t\tx++
+ \t}
+ }
+ "};
+ let buffer = cx.new_model(|cx| Buffer::local(text, cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let new_text = concat!(
+ "func main() {\n",
+ "\tx := 0\n",
+ "\tfor x < 10 {\n",
+ "\t\tx++\n",
+ "\t}", //
+ );
+ chunks_tx.unbounded_send(new_text.to_string()).unwrap();
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ func main() {
+ \tx := 0
+ \tfor x < 10 {
+ \t\tx++
+ \t}
+ }
+ "}
+ );
+ }
+
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@@ -2984,4 +3318,27 @@ mod tests {
)
}
}
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::language()),
+ )
+ .with_indents_query(
+ r#"
+ (call_expression) @indent
+ (field_expression) @indent
+ (_ "(" ")" @end) @indent
+ (_ "{" "}" @end) @indent
+ "#,
+ )
+ .unwrap()
+ }
}
@@ -1,15 +1,15 @@
use feature_flags::ZedPro;
+use gpui::Action;
use gpui::DismissEvent;
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use proto::Plan;
+use workspace::ShowConfiguration;
use std::sync::Arc;
use ui::ListItemSpacing;
use crate::assistant_settings::AssistantSettings;
-use crate::ShowConfiguration;
use fs::Fs;
-use gpui::Action;
use gpui::SharedString;
use gpui::Task;
use picker::{Picker, PickerDelegate};
@@ -36,7 +36,7 @@ pub struct ModelPickerDelegate {
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
- provider_icon: IconName,
+ icon: IconName,
availability: LanguageModelAvailability,
is_selected: bool,
}
@@ -156,7 +156,7 @@ impl PickerDelegate for ModelPickerDelegate {
.selected(selected)
.start_slot(
div().pr_1().child(
- Icon::new(model_info.provider_icon)
+ Icon::new(model_info.icon)
.color(Color::Muted)
.size(IconSize::Medium),
),
@@ -261,16 +261,17 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
.iter()
.flat_map(|provider| {
let provider_id = provider.id();
- let provider_icon = provider.icon();
+ let icon = provider.icon();
let selected_model = selected_model.clone();
let selected_provider = selected_provider.clone();
provider.provided_models(cx).into_iter().map(move |model| {
let model = model.clone();
+ let icon = model.icon().unwrap_or(icon);
ModelInfo {
model: model.clone(),
- provider_icon,
+ icon,
availability: model.availability(),
is_selected: selected_model.as_ref() == Some(&model.id())
&& selected_provider.as_ref() == Some(&provider_id),
@@ -1,26 +1,24 @@
+use anyhow::Result;
use assets::Assets;
use fs::Fs;
use futures::StreamExt;
-use handlebars::{Handlebars, RenderError, TemplateError};
+use gpui::AssetSource;
+use handlebars::{Handlebars, RenderError};
use language::BufferSnapshot;
use parking_lot::Mutex;
use serde::Serialize;
-use std::{ops::Range, sync::Arc, time::Duration};
+use std::{ops::Range, path::PathBuf, sync::Arc, time::Duration};
use util::ResultExt;
#[derive(Serialize)]
pub struct ContentPromptContext {
pub content_type: String,
pub language_name: Option<String>,
+ pub is_insert: bool,
pub is_truncated: bool,
pub document_content: String,
pub user_prompt: String,
- pub rewrite_section: String,
- pub rewrite_section_prefix: String,
- pub rewrite_section_suffix: String,
- pub rewrite_section_with_edits: String,
- pub has_insertion: bool,
- pub has_replacement: bool,
+ pub rewrite_section: Option<String>,
}
#[derive(Serialize)]
@@ -42,128 +40,162 @@ pub struct StepResolutionContext {
pub step_to_resolve: String,
}
-pub struct PromptBuilder {
- handlebars: Arc<Mutex<Handlebars<'static>>>,
+pub struct PromptLoadingParams<'a> {
+ pub fs: Arc<dyn Fs>,
+ pub repo_path: Option<PathBuf>,
+ pub cx: &'a gpui::AppContext,
}
-pub struct PromptOverrideContext<'a> {
- pub dev_mode: bool,
- pub fs: Arc<dyn Fs>,
- pub cx: &'a mut gpui::AppContext,
+pub struct PromptBuilder {
+ handlebars: Arc<Mutex<Handlebars<'static>>>,
}
impl PromptBuilder {
- pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
+ pub fn new(loading_params: Option<PromptLoadingParams>) -> Result<Self> {
let mut handlebars = Handlebars::new();
- Self::register_templates(&mut handlebars)?;
+ Self::register_built_in_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars));
- if let Some(override_cx) = override_cx {
- Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
+ if let Some(params) = loading_params {
+ Self::watch_fs_for_template_overrides(params, handlebars.clone());
}
Ok(Self { handlebars })
}
+ /// Watches the filesystem for changes to prompt template overrides.
+ ///
+ /// This function sets up a file watcher on the prompt templates directory. It performs
+ /// an initial scan of the directory and registers any existing template overrides.
+ /// Then it continuously monitors for changes, reloading templates as they are
+ /// modified or added.
+ ///
+ /// If the templates directory doesn't exist initially, it waits for it to be created.
+ /// If the directory is removed, it restores the built-in templates and waits for the
+ /// directory to be recreated.
+ ///
+ /// # Arguments
+ ///
+ /// * `params` - A `PromptLoadingParams` struct containing the filesystem, repository path,
+ /// and application context.
+ /// * `handlebars` - An `Arc<Mutex<Handlebars>>` for registering and updating templates.
fn watch_fs_for_template_overrides(
- PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
+ mut params: PromptLoadingParams,
handlebars: Arc<Mutex<Handlebars<'static>>>,
) {
- cx.background_executor()
+ params.repo_path = None;
+ let templates_dir = paths::prompt_overrides_dir(params.repo_path.as_deref());
+ params.cx.background_executor()
.spawn(async move {
- let templates_dir = if dev_mode {
- std::env::current_dir()
- .ok()
- .and_then(|pwd| {
- let pwd_assets_prompts = pwd.join("assets").join("prompts");
- pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
- })
- .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
- } else {
- paths::prompt_overrides_dir().clone()
+ let Some(parent_dir) = templates_dir.parent() else {
+ return;
};
- // Create the prompt templates directory if it doesn't exist
- if !fs.is_dir(&templates_dir).await {
- if let Err(e) = fs.create_dir(&templates_dir).await {
- log::error!("Failed to create prompt templates directory: {}", e);
- return;
+ let mut found_dir_once = false;
+ loop {
+                // Check whether the templates directory exists and handle each case:
+                // - If it exists, log its presence and whether it is a symlink.
+                // - If it doesn't exist:
+                //   - Log that we're falling back to the built-in prompts.
+                //   - If there's a broken symlink, log that too.
+                //   - Set up a watcher to detect when the directory is created.
+                // After the first pass, set the `found_dir_once` flag so we don't log
+                // again when looping back around after the overrides directory is deleted.
+ let dir_status = params.fs.is_dir(&templates_dir).await;
+ let symlink_status = params.fs.read_link(&templates_dir).await.ok();
+ if dir_status {
+ let mut log_message = format!("Prompt template overrides directory found at {}", templates_dir.display());
+ if let Some(target) = symlink_status {
+ log_message.push_str(" -> ");
+ log_message.push_str(&target.display().to_string());
+ }
+ log::info!("{}.", log_message);
+ } else {
+ if !found_dir_once {
+ log::info!("No prompt template overrides directory found at {}. Using built-in prompts.", templates_dir.display());
+ if let Some(target) = symlink_status {
+ log::info!("Symlink found pointing to {}, but target is invalid.", target.display());
+ }
+ }
+
+ if params.fs.is_dir(parent_dir).await {
+ let (mut changes, _watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await;
+ while let Some(changed_paths) = changes.next().await {
+ if changed_paths.iter().any(|p| p == &templates_dir) {
+ let mut log_message = format!("Prompt template overrides directory detected at {}", templates_dir.display());
+ if let Ok(target) = params.fs.read_link(&templates_dir).await {
+ log_message.push_str(" -> ");
+ log_message.push_str(&target.display().to_string());
+ }
+ log::info!("{}.", log_message);
+ break;
+ }
+ }
+ } else {
+ return;
+ }
}
- }
- // Initial scan of the prompts directory
- if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
- while let Some(Ok(file_path)) = entries.next().await {
- if file_path.to_string_lossy().ends_with(".hbs") {
- if let Ok(content) = fs.load(&file_path).await {
- let file_name = file_path.file_stem().unwrap().to_string_lossy();
+ found_dir_once = true;
- match handlebars.lock().register_template_string(&file_name, content) {
- Ok(_) => {
- log::info!(
- "Successfully registered template override: {} ({})",
- file_name,
- file_path.display()
- );
- },
- Err(e) => {
- log::error!(
- "Failed to register template during initial scan: {} ({})",
- e,
- file_path.display()
- );
- },
+ // Initial scan of the prompt overrides directory
+ if let Ok(mut entries) = params.fs.read_dir(&templates_dir).await {
+ while let Some(Ok(file_path)) = entries.next().await {
+ if file_path.to_string_lossy().ends_with(".hbs") {
+ if let Ok(content) = params.fs.load(&file_path).await {
+ let file_name = file_path.file_stem().unwrap().to_string_lossy();
+ log::info!("Registering prompt template override: {}", file_name);
+ handlebars.lock().register_template_string(&file_name, content).log_err();
}
}
}
}
- }
- // Watch for changes
- let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
- while let Some(changed_paths) = changes.next().await {
- for changed_path in changed_paths {
- if changed_path.extension().map_or(false, |ext| ext == "hbs") {
- log::info!("Reloading template: {}", changed_path.display());
- if let Some(content) = fs.load(&changed_path).await.log_err() {
- let file_name = changed_path.file_stem().unwrap().to_string_lossy();
- let file_path = changed_path.to_string_lossy();
- match handlebars.lock().register_template_string(&file_name, content) {
- Ok(_) => log::info!(
- "Successfully reloaded template: {} ({})",
- file_name,
- file_path
- ),
- Err(e) => log::error!(
- "Failed to register template: {} ({})",
- e,
- file_path
- ),
+ // Watch both the parent directory and the template overrides directory:
+ // - Monitor the parent directory to detect if the template overrides directory is deleted.
+ // - Monitor the template overrides directory to re-register templates when they change.
+ // Combine both watch streams into a single stream.
+ let (parent_changes, parent_watcher) = params.fs.watch(parent_dir, Duration::from_secs(1)).await;
+ let (changes, watcher) = params.fs.watch(&templates_dir, Duration::from_secs(1)).await;
+ let mut combined_changes = futures::stream::select(changes, parent_changes);
+
+ while let Some(changed_paths) = combined_changes.next().await {
+ if changed_paths.iter().any(|p| p == &templates_dir) {
+ if !params.fs.is_dir(&templates_dir).await {
+ log::info!("Prompt template overrides directory removed. Restoring built-in prompt templates.");
+ Self::register_built_in_templates(&mut handlebars.lock()).log_err();
+ break;
+ }
+ }
+ for changed_path in changed_paths {
+ if changed_path.starts_with(&templates_dir) && changed_path.extension().map_or(false, |ext| ext == "hbs") {
+ log::info!("Reloading prompt template override: {}", changed_path.display());
+ if let Some(content) = params.fs.load(&changed_path).await.log_err() {
+ let file_name = changed_path.file_stem().unwrap().to_string_lossy();
+ handlebars.lock().register_template_string(&file_name, content).log_err();
}
}
}
}
+
+ drop(watcher);
+ drop(parent_watcher);
}
- drop(watcher);
})
.detach();
}
- fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
- let mut register_template = |id: &str| {
- let prompt = Assets::get(&format!("prompts/{}.hbs", id))
- .unwrap_or_else(|| panic!("{} prompt template not found", id))
- .data;
- handlebars
- .register_template_string(id, String::from_utf8_lossy(&prompt))
- .map_err(Box::new)
- };
-
- register_template("content_prompt")?;
- register_template("terminal_assistant_prompt")?;
- register_template("edit_workflow")?;
- register_template("step_resolution")?;
+ fn register_built_in_templates(handlebars: &mut Handlebars) -> Result<()> {
+ for path in Assets.list("prompts")? {
+ if let Some(id) = path.split('/').last().and_then(|s| s.strip_suffix(".hbs")) {
+ if let Some(prompt) = Assets.load(path.as_ref()).log_err().flatten() {
+ log::info!("Registering built-in prompt template: {}", id);
+ handlebars
+ .register_template_string(id, String::from_utf8_lossy(prompt.as_ref()))?
+ }
+ }
+ }
Ok(())
}
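
The override mechanism described above ultimately relies on Handlebars replacing a template when a new one is registered under an existing name. A minimal, standalone sketch of that behavior (not part of the change; the template contents here are made up for illustration):

    use handlebars::Handlebars;
    use std::collections::HashMap;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let mut handlebars = Handlebars::new();

        // Built-in template, as registered by `register_built_in_templates`.
        handlebars.register_template_string("content_prompt", "Rewrite: {{user_prompt}}")?;

        // A user drops a `content_prompt.hbs` file into the overrides directory;
        // the watcher re-registers it under the same id, shadowing the built-in one.
        handlebars.register_template_string("content_prompt", "OVERRIDE: {{user_prompt}}")?;

        let mut data = HashMap::new();
        data.insert("user_prompt", "make this shorter");
        assert_eq!(
            handlebars.render("content_prompt", &data)?,
            "OVERRIDE: make this shorter"
        );
        Ok(())
    }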
@@ -173,9 +205,7 @@ impl PromptBuilder {
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
- transform_range: Range<usize>,
- selected_ranges: Vec<Range<usize>>,
- transform_context_range: Range<usize>,
+ range: Range<usize>,
) -> Result<String, RenderError> {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@@ -183,20 +213,21 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
+ let is_insert = range.is_empty();
let mut is_truncated = false;
- let before_range = 0..transform_range.start;
+ let before_range = 0..range.start;
let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.start - MAX_CTX..transform_range.start
+ range.start - MAX_CTX..range.start
} else {
before_range
};
- let after_range = transform_range.end..buffer.len();
+ let after_range = range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.end..transform_range.end + MAX_CTX
+ range.end..range.end + MAX_CTX
} else {
after_range
};
@@ -205,74 +236,37 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
-
- document_content.push_str("<rewrite_this>\n");
- for chunk in buffer.text_for_range(transform_range.clone()) {
- document_content.push_str(chunk);
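+        // Mark the transformation site for the model: an empty range becomes an
+        // <insert_here></insert_here> marker, while a non-empty range is wrapped
+        // in <rewrite_this>...</rewrite_this> tags.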
+ if is_insert {
+ document_content.push_str("<insert_here></insert_here>");
+ } else {
+ document_content.push_str("<rewrite_this>\n");
+ for chunk in buffer.text_for_range(range.clone()) {
+ document_content.push_str(chunk);
+ }
+ document_content.push_str("\n</rewrite_this>");
}
- document_content.push_str("\n</rewrite_this>");
-
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
- let mut rewrite_section = String::new();
- for chunk in buffer.text_for_range(transform_range.clone()) {
- rewrite_section.push_str(chunk);
- }
-
- let mut rewrite_section_prefix = String::new();
- for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
- rewrite_section_prefix.push_str(chunk);
- }
-
- let mut rewrite_section_suffix = String::new();
- for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
- rewrite_section_suffix.push_str(chunk);
- }
-
- let rewrite_section_with_edits = {
- let mut section_with_selections = String::new();
- let mut last_end = 0;
- for selected_range in &selected_ranges {
- if selected_range.start > last_end {
- section_with_selections.push_str(
- &rewrite_section[last_end..selected_range.start - transform_range.start],
- );
- }
- if selected_range.start == selected_range.end {
- section_with_selections.push_str("<insert_here></insert_here>");
- } else {
- section_with_selections.push_str("<edit_here>");
- section_with_selections.push_str(
- &rewrite_section[selected_range.start - transform_range.start
- ..selected_range.end - transform_range.start],
- );
- section_with_selections.push_str("</edit_here>");
- }
- last_end = selected_range.end - transform_range.start;
- }
- if last_end < rewrite_section.len() {
- section_with_selections.push_str(&rewrite_section[last_end..]);
+ let rewrite_section = if !is_insert {
+ let mut section = String::new();
+ for chunk in buffer.text_for_range(range.clone()) {
+ section.push_str(chunk);
}
- section_with_selections
+ Some(section)
+ } else {
+ None
};
- let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
- let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
-
let context = ContentPromptContext {
content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()),
+ is_insert,
is_truncated,
document_content,
user_prompt,
rewrite_section,
- rewrite_section_prefix,
- rewrite_section_suffix,
- rewrite_section_with_edits,
- has_insertion,
- has_replacement,
};
self.handlebars.lock().render("content_prompt", &context)
@@ -124,6 +124,7 @@ impl SlashCommandCompletionProvider {
&command_name,
&[],
true,
+ false,
workspace.clone(),
cx,
);
@@ -208,6 +209,7 @@ impl SlashCommandCompletionProvider {
&command_name,
&completed_arguments,
true,
+ false,
workspace.clone(),
cx,
);
@@ -67,7 +67,11 @@ impl SlashCommand for ContextServerSlashCommand {
) -> Task<Result<SlashCommandOutput>> {
let server_id = self.server_id.clone();
let prompt_name = self.prompt.name.clone();
- let argument = arguments.first().cloned();
+
+ let prompt_args = match prompt_arguments(&self.prompt, arguments) {
+ Ok(args) => args,
+ Err(e) => return Task::ready(Err(e)),
+ };
let manager = ContextServerManager::global(cx);
let manager = manager.read(cx);
@@ -76,10 +80,7 @@ impl SlashCommand for ContextServerSlashCommand {
let Some(protocol) = server.client.read().clone() else {
return Err(anyhow!("Context server not initialized"));
};
-
- let result = protocol
- .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
- .await?;
+ let result = protocol.run_prompt(&prompt_name, prompt_args).await?;
Ok(SlashCommandOutput {
sections: vec![SlashCommandOutputSection {
@@ -97,19 +98,27 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
-fn prompt_arguments(
- prompt: &PromptInfo,
- argument: Option<String>,
-) -> Result<HashMap<String, String>> {
+fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
match &prompt.arguments {
- Some(args) if args.len() >= 2 => Err(anyhow!(
+ Some(args) if args.len() > 1 => Err(anyhow!(
"Prompt has more than one argument, which is not supported"
)),
- Some(args) if args.len() == 1 => match argument {
- Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
- None => Err(anyhow!("Prompt expects argument but none given")),
- },
- Some(_) | None => Ok(HashMap::default()),
+ Some(args) if args.len() == 1 => {
+ if !arguments.is_empty() {
+ let mut map = HashMap::default();
+ map.insert(args[0].name.clone(), arguments.join(" "));
+ Ok(map)
+ } else {
+ Err(anyhow!("Prompt expects argument but none given"))
+ }
+ }
+ Some(_) | None => {
+ if arguments.is_empty() {
+ Ok(HashMap::default())
+ } else {
+ Err(anyhow!("Prompt expects no arguments but some were given"))
+ }
+ }
}
}
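
As a rough illustration of the argument handling above, here is a standalone sketch using simplified stand-ins for the protocol types (the real `PromptInfo` comes from the context server protocol and has more fields); it only shows that every token the user typed is joined with spaces into the prompt's single declared argument:

    use std::collections::HashMap;

    // Hypothetical, simplified stand-ins for the real protocol types.
    struct PromptArg { name: String }
    struct PromptInfo { arguments: Option<Vec<PromptArg>> }

    fn join_args(prompt: &PromptInfo, arguments: &[String]) -> Option<HashMap<String, String>> {
        match prompt.arguments.as_deref() {
            // Exactly one declared argument: join every token the user typed.
            Some([arg]) if !arguments.is_empty() => {
                Some(HashMap::from([(arg.name.clone(), arguments.join(" "))]))
            }
            // No declared arguments: only valid when the user passed none.
            Some([]) | None if arguments.is_empty() => Some(HashMap::new()),
            // Everything else corresponds to the error cases in `prompt_arguments`.
            _ => None,
        }
    }

    fn main() {
        let prompt = PromptInfo {
            arguments: Some(vec![PromptArg { name: "topic".into() }]),
        };
        let tokens = vec!["rust".to_string(), "async".to_string()];
        let args = join_args(&prompt, &tokens).unwrap();
        assert_eq!(args.get("topic").map(String::as_str), Some("rust async"));
    }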
@@ -0,0 +1,306 @@
+use std::sync::Arc;
+
+use assistant_slash_command::SlashCommandRegistry;
+use gpui::AnyElement;
+use gpui::DismissEvent;
+use gpui::WeakView;
+use picker::PickerEditorPosition;
+
+use ui::ListItemSpacing;
+
+use gpui::SharedString;
+use gpui::Task;
+use picker::{Picker, PickerDelegate};
+use ui::{prelude::*, ListItem, PopoverMenu, PopoverTrigger};
+
+use crate::assistant_panel::ContextEditor;
+
+#[derive(IntoElement)]
+pub(super) struct SlashCommandSelector<T: PopoverTrigger> {
+ registry: Arc<SlashCommandRegistry>,
+ active_context_editor: WeakView<ContextEditor>,
+ trigger: T,
+}
+
+#[derive(Clone)]
+struct SlashCommandInfo {
+ name: SharedString,
+ description: SharedString,
+ args: Option<SharedString>,
+}
+
+#[derive(Clone)]
+enum SlashCommandEntry {
+ Info(SlashCommandInfo),
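+    /// A promotional entry (for example, a documentation link) rather than an
+    /// invocable slash command.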
+ Advert {
+ name: SharedString,
+ renderer: fn(&mut WindowContext<'_>) -> AnyElement,
+ on_confirm: fn(&mut WindowContext<'_>),
+ },
+}
+
+impl AsRef<str> for SlashCommandEntry {
+ fn as_ref(&self) -> &str {
+ match self {
+ SlashCommandEntry::Info(SlashCommandInfo { name, .. })
+ | SlashCommandEntry::Advert { name, .. } => name,
+ }
+ }
+}
+
+pub(crate) struct SlashCommandDelegate {
+ all_commands: Vec<SlashCommandEntry>,
+ filtered_commands: Vec<SlashCommandEntry>,
+ active_context_editor: WeakView<ContextEditor>,
+ selected_index: usize,
+}
+
+impl<T: PopoverTrigger> SlashCommandSelector<T> {
+ pub(crate) fn new(
+ registry: Arc<SlashCommandRegistry>,
+ active_context_editor: WeakView<ContextEditor>,
+ trigger: T,
+ ) -> Self {
+ SlashCommandSelector {
+ registry,
+ active_context_editor,
+ trigger,
+ }
+ }
+}
+
+impl PickerDelegate for SlashCommandDelegate {
+ type ListItem = ListItem;
+
+ fn match_count(&self) -> usize {
+ self.filtered_commands.len()
+ }
+
+ fn selected_index(&self) -> usize {
+ self.selected_index
+ }
+
+ fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
+ self.selected_index = ix.min(self.filtered_commands.len().saturating_sub(1));
+ cx.notify();
+ }
+
+ fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
+ "Select a command...".into()
+ }
+
+ fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
+ let all_commands = self.all_commands.clone();
+ cx.spawn(|this, mut cx| async move {
+ let filtered_commands = cx
+ .background_executor()
+ .spawn(async move {
+ if query.is_empty() {
+ all_commands
+ } else {
+ all_commands
+ .into_iter()
+ .filter(|model_info| {
+ model_info
+ .as_ref()
+ .to_lowercase()
+ .contains(&query.to_lowercase())
+ })
+ .collect()
+ }
+ })
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ this.delegate.filtered_commands = filtered_commands;
+ this.delegate.set_selected_index(0, cx);
+ cx.notify();
+ })
+ .ok();
+ })
+ }
+
+ fn separators_after_indices(&self) -> Vec<usize> {
+ let mut ret = vec![];
+ let mut previous_is_advert = false;
+
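+        // Emit separators around runs of advert entries so they are visually
+        // set apart from the actual slash commands.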
+ for (index, command) in self.filtered_commands.iter().enumerate() {
+ if previous_is_advert {
+ if let SlashCommandEntry::Info(_) = command {
+ previous_is_advert = false;
+ debug_assert_ne!(
+ index, 0,
+                        "index cannot be zero, as we can never have a separator at the 0th position"
+ );
+ ret.push(index - 1);
+ }
+ } else {
+ if let SlashCommandEntry::Advert { .. } = command {
+ previous_is_advert = true;
+ if index != 0 {
+ ret.push(index - 1);
+ }
+ }
+ }
+ }
+ ret
+ }
+ fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
+ if let Some(command) = self.filtered_commands.get(self.selected_index) {
+ if let SlashCommandEntry::Info(info) = command {
+ self.active_context_editor
+ .update(cx, |context_editor, cx| {
+ context_editor.insert_command(&info.name, cx)
+ })
+ .ok();
+ } else if let SlashCommandEntry::Advert { on_confirm, .. } = command {
+ on_confirm(cx);
+ }
+ cx.emit(DismissEvent);
+ }
+ }
+
+ fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
+
+ fn editor_position(&self) -> PickerEditorPosition {
+ PickerEditorPosition::End
+ }
+
+ fn render_match(
+ &self,
+ ix: usize,
+ selected: bool,
+ cx: &mut ViewContext<Picker<Self>>,
+ ) -> Option<Self::ListItem> {
+ let command_info = self.filtered_commands.get(ix)?;
+
+ match command_info {
+ SlashCommandEntry::Info(info) => Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .selected(selected)
+ .child(
+ h_flex()
+ .group(format!("command-entry-label-{ix}"))
+ .w_full()
+ .min_w(px(220.))
+ .child(
+ v_flex()
+ .child(
+ h_flex()
+ .child(div().font_buffer(cx).child({
+ let mut label = format!("/{}", info.name);
+ if let Some(args) =
+ info.args.as_ref().filter(|_| selected)
+ {
+ label.push_str(&args);
+ }
+ Label::new(label).size(LabelSize::Small)
+ }))
+ .children(info.args.clone().filter(|_| !selected).map(
+ |args| {
+ div()
+ .font_buffer(cx)
+ .child(
+ Label::new(args).size(LabelSize::Small),
+ )
+ .visible_on_hover(format!(
+ "command-entry-label-{ix}"
+ ))
+ },
+ )),
+ )
+ .child(
+ Label::new(info.description.clone())
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ ),
+ ),
+ ),
+ SlashCommandEntry::Advert { renderer, .. } => Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .selected(selected)
+ .child(renderer(cx)),
+ ),
+ }
+ }
+}
+
+impl<T: PopoverTrigger> RenderOnce for SlashCommandSelector<T> {
+ fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ let all_models = self
+ .registry
+ .featured_command_names()
+ .into_iter()
+ .filter_map(|command_name| {
+ let command = self.registry.command(&command_name)?;
+ let menu_text = SharedString::from(Arc::from(command.menu_text()));
+ let label = command.label(cx);
+ let args = label.filter_range.end.ne(&label.text.len()).then(|| {
+ SharedString::from(
+ label.text[label.filter_range.end..label.text.len()].to_owned(),
+ )
+ });
+ Some(SlashCommandEntry::Info(SlashCommandInfo {
+ name: command_name.into(),
+ description: menu_text,
+ args,
+ }))
+ })
+ .chain([SlashCommandEntry::Advert {
+ name: "create-your-command".into(),
+ renderer: |cx| {
+ v_flex()
+ .child(
+ h_flex()
+ .font_buffer(cx)
+ .items_center()
+ .gap_1()
+ .child(div().font_buffer(cx).child(
+ Label::new("create-your-command").size(LabelSize::Small),
+ ))
+ .child(Icon::new(IconName::ArrowUpRight).size(IconSize::XSmall)),
+ )
+ .child(
+ Label::new("Learn how to create a custom command")
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element()
+ },
+ on_confirm: |cx| cx.open_url("https://zed.dev/docs/extensions/slash-commands"),
+ }])
+ .collect::<Vec<_>>();
+
+ let delegate = SlashCommandDelegate {
+ all_commands: all_models.clone(),
+ active_context_editor: self.active_context_editor.clone(),
+ filtered_commands: all_models,
+ selected_index: 0,
+ };
+
+ let picker_view = cx.new_view(|cx| {
+ let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
+ picker
+ });
+
+ let handle = self
+ .active_context_editor
+ .update(cx, |this, _| this.slash_menu_handle.clone())
+ .ok();
+ PopoverMenu::new("model-switcher")
+ .menu(move |_cx| Some(picker_view.clone()))
+ .trigger(self.trigger)
+ .attach(gpui::AnchorCorner::TopLeft)
+ .anchor(gpui::AnchorCorner::BottomLeft)
+ .offset(gpui::Point {
+ x: px(0.0),
+ y: px(-16.0),
+ })
+ .when_some(handle, |this, handle| this.with_handle(handle))
+ }
+}
@@ -1,42 +0,0 @@
-## Assistant Panel
-
-Once you have configured a provider, you can interact with the provider's language models in a context editor.
-
-To create a new context editor, use the menu in the top right of the assistant panel and select the `New Context` option.
-
-In the context editor, select a model from one of the configured providers, type a message in the `You` block, and submit with `cmd-enter` (or `ctrl-enter` on Linux).
-
-### Adding Prompts
-
-You can customize the default prompts used in new context editors by opening the `Prompt Library`.
-
-Open the `Prompt Library` using either the menu in the top right of the assistant panel and choosing the `Prompt Library` option, or by using the `assistant: deploy prompt library` command when the assistant panel is focused.
-
-### Viewing past contexts
-
-You can view all previous contexts by opening the `History` tab in the assistant panel.
-
-Open the `History` using the menu in the top right of the assistant panel and choosing `History`.
-
-### Slash commands
-
-Slash commands enhance the assistant's capabilities. Begin by typing a `/` at the beginning of the line to see a list of available commands:
-
-- default: Inserts the default prompt into the context
-- diagnostics: Injects errors reported by the project's language server into the context
-- fetch: Pulls the content of a webpage and inserts it into the context
-- file: Pulls a single file or a directory of files into the context
-- now: Inserts the current date and time into the context
-- prompt: Adds a custom-configured prompt to the context (see Prompt Library)
-- search: Performs semantic search for content in your project based on natural language
-- symbols: Pulls the current tab's active symbols into the context
-- tab: Pulls in the content of the active tab or all open tabs into the context
-- terminal: Pulls in a select number of lines of output from the terminal
-
-## Inline assistant
-
-You can use `ctrl-enter` to open the inline assistant in both a normal editor and within the assistant panel.
-
-The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response.
-
-The inline assistant pulls its context from the assistant panel, allowing you to provide additional instructions or rules for code transformations.
@@ -273,7 +273,7 @@ impl Item for WorkflowStepView {
}
fn tab_icon(&self, _cx: &WindowContext) -> Option<ui::Icon> {
- Some(Icon::new(IconName::Pencil))
+ Some(Icon::new(IconName::SearchCode))
}
fn to_item_events(event: &Self::Event, mut f: impl FnMut(item::ItemEvent)) {
@@ -295,7 +295,8 @@ CREATE UNIQUE INDEX "index_channel_buffer_collaborators_on_channel_id_connection
CREATE TABLE "feature_flags" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
- "flag" TEXT NOT NULL UNIQUE
+ "flag" TEXT NOT NULL UNIQUE,
+ "enabled_for_all" BOOLEAN NOT NULL DEFAULT false
);
CREATE INDEX "index_feature_flags" ON "feature_flags" ("id");
@@ -0,0 +1 @@
+alter table feature_flags add column enabled_for_all boolean not null default false;
@@ -1,5 +1,6 @@
use super::ips_file::IpsFile;
use crate::api::CloudflareIpCountryHeader;
+use crate::clickhouse::write_to_table;
use crate::{api::slack, AppState, Error, Result};
use anyhow::{anyhow, Context};
use aws_sdk_s3::primitives::ByteStream;
@@ -529,12 +530,12 @@ struct ToUpload {
impl ToUpload {
pub async fn upload(&self, clickhouse_client: &clickhouse::Client) -> anyhow::Result<()> {
const EDITOR_EVENTS_TABLE: &str = "editor_events";
- Self::upload_to_table(EDITOR_EVENTS_TABLE, &self.editor_events, clickhouse_client)
+ write_to_table(EDITOR_EVENTS_TABLE, &self.editor_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{EDITOR_EVENTS_TABLE}'"))?;
const INLINE_COMPLETION_EVENTS_TABLE: &str = "inline_completion_events";
- Self::upload_to_table(
+ write_to_table(
INLINE_COMPLETION_EVENTS_TABLE,
&self.inline_completion_events,
clickhouse_client,
@@ -543,7 +544,7 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{INLINE_COMPLETION_EVENTS_TABLE}'"))?;
const ASSISTANT_EVENTS_TABLE: &str = "assistant_events";
- Self::upload_to_table(
+ write_to_table(
ASSISTANT_EVENTS_TABLE,
&self.assistant_events,
clickhouse_client,
@@ -552,27 +553,27 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{ASSISTANT_EVENTS_TABLE}'"))?;
const CALL_EVENTS_TABLE: &str = "call_events";
- Self::upload_to_table(CALL_EVENTS_TABLE, &self.call_events, clickhouse_client)
+ write_to_table(CALL_EVENTS_TABLE, &self.call_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{CALL_EVENTS_TABLE}'"))?;
const CPU_EVENTS_TABLE: &str = "cpu_events";
- Self::upload_to_table(CPU_EVENTS_TABLE, &self.cpu_events, clickhouse_client)
+ write_to_table(CPU_EVENTS_TABLE, &self.cpu_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{CPU_EVENTS_TABLE}'"))?;
const MEMORY_EVENTS_TABLE: &str = "memory_events";
- Self::upload_to_table(MEMORY_EVENTS_TABLE, &self.memory_events, clickhouse_client)
+ write_to_table(MEMORY_EVENTS_TABLE, &self.memory_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{MEMORY_EVENTS_TABLE}'"))?;
const APP_EVENTS_TABLE: &str = "app_events";
- Self::upload_to_table(APP_EVENTS_TABLE, &self.app_events, clickhouse_client)
+ write_to_table(APP_EVENTS_TABLE, &self.app_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{APP_EVENTS_TABLE}'"))?;
const SETTING_EVENTS_TABLE: &str = "setting_events";
- Self::upload_to_table(
+ write_to_table(
SETTING_EVENTS_TABLE,
&self.setting_events,
clickhouse_client,
@@ -581,7 +582,7 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{SETTING_EVENTS_TABLE}'"))?;
const EXTENSION_EVENTS_TABLE: &str = "extension_events";
- Self::upload_to_table(
+ write_to_table(
EXTENSION_EVENTS_TABLE,
&self.extension_events,
clickhouse_client,
@@ -590,48 +591,22 @@ impl ToUpload {
.with_context(|| format!("failed to upload to table '{EXTENSION_EVENTS_TABLE}'"))?;
const EDIT_EVENTS_TABLE: &str = "edit_events";
- Self::upload_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
+ write_to_table(EDIT_EVENTS_TABLE, &self.edit_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{EDIT_EVENTS_TABLE}'"))?;
const ACTION_EVENTS_TABLE: &str = "action_events";
- Self::upload_to_table(ACTION_EVENTS_TABLE, &self.action_events, clickhouse_client)
+ write_to_table(ACTION_EVENTS_TABLE, &self.action_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{ACTION_EVENTS_TABLE}'"))?;
const REPL_EVENTS_TABLE: &str = "repl_events";
- Self::upload_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
+ write_to_table(REPL_EVENTS_TABLE, &self.repl_events, clickhouse_client)
.await
.with_context(|| format!("failed to upload to table '{REPL_EVENTS_TABLE}'"))?;
Ok(())
}
-
- async fn upload_to_table<T: clickhouse::Row + Serialize + std::fmt::Debug>(
- table: &str,
- rows: &[T],
- clickhouse_client: &clickhouse::Client,
- ) -> anyhow::Result<()> {
- if rows.is_empty() {
- return Ok(());
- }
-
- let mut insert = clickhouse_client.insert(table)?;
-
- for event in rows {
- insert.write(event).await?;
- }
-
- insert.end().await?;
-
- let event_count = rows.len();
- log::info!(
- "wrote {event_count} {event_specifier} to '{table}'",
- event_specifier = if event_count == 1 { "event" } else { "events" }
- );
-
- Ok(())
- }
}
pub fn serialize_country_code<S>(country_code: &str, serializer: S) -> Result<S::Ok, S::Error>
@@ -0,0 +1,28 @@
+use serde::Serialize;
+
+/// Writes the given rows to the specified ClickHouse table.
+pub async fn write_to_table<T: clickhouse::Row + Serialize + std::fmt::Debug>(
+ table: &str,
+ rows: &[T],
+ clickhouse_client: &clickhouse::Client,
+) -> anyhow::Result<()> {
+ if rows.is_empty() {
+ return Ok(());
+ }
+
+ let mut insert = clickhouse_client.insert(table)?;
+
+ for event in rows {
+ insert.write(event).await?;
+ }
+
+ insert.end().await?;
+
+ let event_count = rows.len();
+ log::info!(
+ "wrote {event_count} {event_specifier} to '{table}'",
+ event_specifier = if event_count == 1 { "event" } else { "events" }
+ );
+
+ Ok(())
+}
@@ -312,10 +312,11 @@ impl Database {
}
/// Creates a new feature flag.
- pub async fn create_user_flag(&self, flag: &str) -> Result<FlagId> {
+ pub async fn create_user_flag(&self, flag: &str, enabled_for_all: bool) -> Result<FlagId> {
self.transaction(|tx| async move {
let flag = feature_flag::Entity::insert(feature_flag::ActiveModel {
flag: ActiveValue::set(flag.to_string()),
+ enabled_for_all: ActiveValue::set(enabled_for_all),
..Default::default()
})
.exec(&*tx)
@@ -350,7 +351,15 @@ impl Database {
Flag,
}
- let flags = user::Model {
+ let flags_enabled_for_all = feature_flag::Entity::find()
+ .filter(feature_flag::Column::EnabledForAll.eq(true))
+ .select_only()
+ .column(feature_flag::Column::Flag)
+ .into_values::<_, QueryAs>()
+ .all(&*tx)
+ .await?;
+
+ let flags_enabled_for_user = user::Model {
id: user,
..Default::default()
}
@@ -361,7 +370,10 @@ impl Database {
.all(&*tx)
.await?;
- Ok(flags)
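+        // Union the flags that are enabled for everyone with the user's
+        // individually-assigned flags.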
+ let mut all_flags = HashSet::from_iter(flags_enabled_for_all);
+ all_flags.extend(flags_enabled_for_user);
+
+ Ok(all_flags.into_iter().collect())
})
.await
}
@@ -8,6 +8,7 @@ pub struct Model {
#[sea_orm(primary_key)]
pub id: FlagId,
pub flag: String,
+ pub enabled_for_all: bool,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -2,6 +2,7 @@ use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
+use pretty_assertions::assert_eq;
use std::sync::Arc;
test_both_dbs!(
@@ -37,22 +38,27 @@ async fn test_get_user_flags(db: &Arc<Database>) {
.unwrap()
.user_id;
- const CHANNELS_ALPHA: &str = "channels-alpha";
- const NEW_SEARCH: &str = "new-search";
+ const FEATURE_FLAG_ONE: &str = "brand-new-ux";
+ const FEATURE_FLAG_TWO: &str = "cool-feature";
+ const FEATURE_FLAG_THREE: &str = "feature-enabled-for-everyone";
- let channels_flag = db.create_user_flag(CHANNELS_ALPHA).await.unwrap();
- let search_flag = db.create_user_flag(NEW_SEARCH).await.unwrap();
+ let feature_flag_one = db.create_user_flag(FEATURE_FLAG_ONE, false).await.unwrap();
+ let feature_flag_two = db.create_user_flag(FEATURE_FLAG_TWO, false).await.unwrap();
+ db.create_user_flag(FEATURE_FLAG_THREE, true).await.unwrap();
- db.add_user_flag(user_1, channels_flag).await.unwrap();
- db.add_user_flag(user_1, search_flag).await.unwrap();
+ db.add_user_flag(user_1, feature_flag_one).await.unwrap();
+ db.add_user_flag(user_1, feature_flag_two).await.unwrap();
- db.add_user_flag(user_2, channels_flag).await.unwrap();
+ db.add_user_flag(user_2, feature_flag_one).await.unwrap();
let mut user_1_flags = db.get_user_flags(user_1).await.unwrap();
user_1_flags.sort();
- assert_eq!(user_1_flags, &[CHANNELS_ALPHA, NEW_SEARCH]);
+ assert_eq!(
+ user_1_flags,
+ &[FEATURE_FLAG_ONE, FEATURE_FLAG_TWO, FEATURE_FLAG_THREE]
+ );
let mut user_2_flags = db.get_user_flags(user_2).await.unwrap();
user_2_flags.sort();
- assert_eq!(user_2_flags, &[CHANNELS_ALPHA]);
+ assert_eq!(user_2_flags, &[FEATURE_FLAG_ONE, FEATURE_FLAG_THREE]);
}
@@ -1,5 +1,6 @@
pub mod api;
pub mod auth;
+pub mod clickhouse;
pub mod db;
pub mod env;
pub mod executor;
@@ -267,7 +268,7 @@ pub struct AppState {
pub stripe_client: Option<Arc<stripe::Client>>,
pub rate_limiter: Arc<RateLimiter>,
pub executor: Executor,
- pub clickhouse_client: Option<clickhouse::Client>,
+ pub clickhouse_client: Option<::clickhouse::Client>,
pub config: Config,
}
@@ -358,8 +359,8 @@ async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::
Ok(aws_sdk_s3::Client::new(&s3_config))
}
-fn build_clickhouse_client(config: &Config) -> anyhow::Result<clickhouse::Client> {
- Ok(clickhouse::Client::default()
+fn build_clickhouse_client(config: &Config) -> anyhow::Result<::clickhouse::Client> {
+ Ok(::clickhouse::Client::default()
.with_url(
config
.clickhouse_url
@@ -141,7 +141,8 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
tracing::Span::current()
.record("user_id", claims.user_id)
.record("login", claims.github_user_login.clone())
- .record("authn.jti", &claims.jti);
+ .record("authn.jti", &claims.jti)
+ .record("is_staff", &claims.is_staff);
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
@@ -169,7 +170,10 @@ async fn perform_completion(
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
Json(params): Json<PerformCompletionParams>,
) -> Result<impl IntoResponse> {
- let model = normalize_model_name(params.provider, params.model);
+ let model = normalize_model_name(
+ state.db.model_names_for_provider(params.provider),
+ params.model,
+ );
authorize_access_to_language_model(
&state.config,
@@ -200,17 +204,21 @@ async fn perform_completion(
let mut request: anthropic::Request =
serde_json::from_str(¶ms.provider_request.get())?;
- // Parse the model, throw away the version that was included, and then set a specific
- // version that we control on the server.
+ // Override the model on the request with the latest version of the model that is
+ // known to the server.
+ //
// Right now, we use the version that's defined in `model.id()`, but we will likely
// want to change this code once a new version of an Anthropic model is released,
// so that users can use the new version, without having to update Zed.
- request.model = match anthropic::Model::from_id(&request.model) {
- Ok(model) => model.id().to_string(),
- Err(_) => request.model,
+ request.model = match model.as_str() {
+ "claude-3-5-sonnet" => anthropic::Model::Claude3_5Sonnet.id().to_string(),
+ "claude-3-opus" => anthropic::Model::Claude3Opus.id().to_string(),
+ "claude-3-haiku" => anthropic::Model::Claude3Haiku.id().to_string(),
+ "claude-3-sonnet" => anthropic::Model::Claude3Sonnet.id().to_string(),
+ _ => request.model,
};
- let chunks = anthropic::stream_completion(
+ let (chunks, rate_limit_info) = anthropic::stream_completion_with_rate_limit_info(
&state.http_client,
anthropic::ANTHROPIC_API_URL,
api_key,
@@ -238,6 +246,19 @@ async fn perform_completion(
anthropic::AnthropicError::Other(err) => Error::Internal(err),
})?;
+ if let Some(rate_limit_info) = rate_limit_info {
+ tracing::info!(
+ target: "upstream rate limit",
+ is_staff = claims.is_staff,
+ provider = params.provider.to_string(),
+ model = model,
+ tokens_remaining = rate_limit_info.tokens_remaining,
+ requests_remaining = rate_limit_info.requests_remaining,
+ requests_reset = ?rate_limit_info.requests_reset,
+ tokens_reset = ?rate_limit_info.tokens_reset,
+ );
+ }
+
chunks
.map(move |event| {
let chunk = event?;
@@ -369,31 +390,13 @@ async fn perform_completion(
})))
}
-fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
- let prefixes: &[_] = match provider {
- LanguageModelProvider::Anthropic => &[
- "claude-3-5-sonnet",
- "claude-3-haiku",
- "claude-3-opus",
- "claude-3-sonnet",
- ],
- LanguageModelProvider::OpenAi => &[
- "gpt-3.5-turbo",
- "gpt-4-turbo-preview",
- "gpt-4o-mini",
- "gpt-4o",
- "gpt-4",
- ],
- LanguageModelProvider::Google => &[],
- LanguageModelProvider::Zed => &[],
- };
-
- if let Some(prefix) = prefixes
+fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
+ if let Some(known_model_name) = known_models
.iter()
- .filter(|&&prefix| name.starts_with(prefix))
- .max_by_key(|&&prefix| prefix.len())
+ .filter(|known_model_name| name.starts_with(known_model_name.as_str()))
+ .max_by_key(|known_model_name| known_model_name.len())
{
- prefix.to_string()
+ known_model_name.to_string()
} else {
name
}
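The function now performs a longest-prefix match against the model names registered in the LLM database instead of a hard-coded per-provider list. A short sketch of the expected behavior, with illustrative model names:

// The longest known prefix wins; unknown names pass through unchanged.
let known = vec!["claude-3-sonnet".to_string(), "claude-3-5-sonnet".to_string()];
assert_eq!(
    normalize_model_name(known.clone(), "claude-3-5-sonnet-20240620".to_string()),
    "claude-3-5-sonnet"
);
assert_eq!(
    normalize_model_name(known, "some-unknown-model".to_string()),
    "some-unknown-model"
);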
@@ -551,33 +554,75 @@ impl<S> Drop for TokenCountingStream<S> {
.await
.log_err();
- if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
- report_llm_usage(
- clickhouse_client,
- LlmUsageEventRow {
- time: Utc::now().timestamp_millis(),
- user_id: claims.user_id as i32,
- is_staff: claims.is_staff,
- plan: match claims.plan {
- Plan::Free => "free".to_string(),
- Plan::ZedPro => "zed_pro".to_string(),
+ if let Some(usage) = usage {
+ tracing::info!(
+ target: "user usage",
+ user_id = claims.user_id,
+ login = claims.github_user_login,
+ authn.jti = claims.jti,
+ is_staff = claims.is_staff,
+ requests_this_minute = usage.requests_this_minute,
+ tokens_this_minute = usage.tokens_this_minute,
+ );
+
+ if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
+ report_llm_usage(
+ clickhouse_client,
+ LlmUsageEventRow {
+ time: Utc::now().timestamp_millis(),
+ user_id: claims.user_id as i32,
+ is_staff: claims.is_staff,
+ plan: match claims.plan {
+ Plan::Free => "free".to_string(),
+ Plan::ZedPro => "zed_pro".to_string(),
+ },
+ model,
+ provider: provider.to_string(),
+ input_token_count: input_token_count as u64,
+ output_token_count: output_token_count as u64,
+ requests_this_minute: usage.requests_this_minute as u64,
+ tokens_this_minute: usage.tokens_this_minute as u64,
+ tokens_this_day: usage.tokens_this_day as u64,
+ input_tokens_this_month: usage.input_tokens_this_month as u64,
+ output_tokens_this_month: usage.output_tokens_this_month as u64,
+ spending_this_month: usage.spending_this_month as u64,
+ lifetime_spending: usage.lifetime_spending as u64,
},
- model,
- provider: provider.to_string(),
- input_token_count: input_token_count as u64,
- output_token_count: output_token_count as u64,
- requests_this_minute: usage.requests_this_minute as u64,
- tokens_this_minute: usage.tokens_this_minute as u64,
- tokens_this_day: usage.tokens_this_day as u64,
- input_tokens_this_month: usage.input_tokens_this_month as u64,
- output_tokens_this_month: usage.output_tokens_this_month as u64,
- spending_this_month: usage.spending_this_month as u64,
- lifetime_spending: usage.lifetime_spending as u64,
- },
- )
- .await
- .log_err();
+ )
+ .await
+ .log_err();
+ }
}
})
}
}
+
+pub fn log_usage_periodically(state: Arc<LlmState>) {
+ state.executor.clone().spawn_detached(async move {
+ loop {
+ state
+ .executor
+ .sleep(std::time::Duration::from_secs(30))
+ .await;
+
+ let Some(usages) = state
+ .db
+ .get_application_wide_usages_by_model(Utc::now())
+ .await
+ .log_err()
+ else {
+ continue;
+ };
+
+ for usage in usages {
+ tracing::info!(
+ target: "computed usage",
+ provider = usage.provider.to_string(),
+ model = usage.model,
+ requests_this_minute = usage.requests_this_minute,
+ tokens_this_minute = usage.tokens_this_minute,
+ );
+ }
+ }
+ })
+}
@@ -26,9 +26,7 @@ fn authorize_access_to_model(
}
match (provider, model) {
- (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
- Ok(())
- }
+ (LanguageModelProvider::Anthropic, "claude-3-5-sonnet") => Ok(()),
_ => Err(Error::http(
StatusCode::FORBIDDEN,
format!("access to model {model:?} is not included in your plan"),
@@ -67,6 +67,21 @@ impl LlmDatabase {
Ok(())
}
+ /// Returns the names of the known models for the given [`LanguageModelProvider`].
+ pub fn model_names_for_provider(&self, provider: LanguageModelProvider) -> Vec<String> {
+ self.models
+ .keys()
+ .filter_map(|(model_provider, model_name)| {
+ if model_provider == &provider {
+ Some(model_name)
+ } else {
+ None
+ }
+ })
+ .cloned()
+ .collect::<Vec<_>>()
+ }
+
pub fn model(&self, provider: LanguageModelProvider, name: &str) -> Result<&model::Model> {
Ok(self
.models
@@ -1,5 +1,6 @@
use crate::db::UserId;
use chrono::Duration;
+use futures::StreamExt as _;
use rpc::LanguageModelProvider;
use sea_orm::QuerySelect;
use std::{iter, str::FromStr};
@@ -18,6 +19,14 @@ pub struct Usage {
pub lifetime_spending: usize,
}
+#[derive(Debug, PartialEq, Clone)]
+pub struct ApplicationWideUsage {
+ pub provider: LanguageModelProvider,
+ pub model: String,
+ pub requests_this_minute: usize,
+ pub tokens_this_minute: usize,
+}
+
#[derive(Clone, Copy, Debug, Default)]
pub struct ActiveUserCount {
pub users_in_recent_minutes: usize,
@@ -63,6 +72,71 @@ impl LlmDatabase {
Ok(())
}
+ pub async fn get_application_wide_usages_by_model(
+ &self,
+ now: DateTimeUtc,
+ ) -> Result<Vec<ApplicationWideUsage>> {
+ self.transaction(|tx| async move {
+ let past_minute = now - Duration::minutes(1);
+ let requests_per_minute = self.usage_measure_ids[&UsageMeasure::RequestsPerMinute];
+ let tokens_per_minute = self.usage_measure_ids[&UsageMeasure::TokensPerMinute];
+
+ let mut results = Vec::new();
+ for (provider, model) in self.models.keys().cloned() {
+ let mut usages = usage::Entity::find()
+ .filter(
+ usage::Column::Timestamp
+ .gte(past_minute.naive_utc())
+ .and(usage::Column::IsStaff.eq(false))
+ .and(
+ usage::Column::MeasureId
+ .eq(requests_per_minute)
+ .or(usage::Column::MeasureId.eq(tokens_per_minute)),
+ ),
+ )
+ .stream(&*tx)
+ .await?;
+
+ let mut requests_this_minute = 0;
+ let mut tokens_this_minute = 0;
+ while let Some(usage) = usages.next().await {
+ let usage = usage?;
+ if usage.measure_id == requests_per_minute {
+ requests_this_minute += Self::get_live_buckets(
+ &usage,
+ now.naive_utc(),
+ UsageMeasure::RequestsPerMinute,
+ )
+ .0
+ .iter()
+ .copied()
+ .sum::<i64>() as usize;
+ } else if usage.measure_id == tokens_per_minute {
+ tokens_this_minute += Self::get_live_buckets(
+ &usage,
+ now.naive_utc(),
+ UsageMeasure::TokensPerMinute,
+ )
+ .0
+ .iter()
+ .copied()
+ .sum::<i64>() as usize;
+ }
+ }
+
+ results.push(ApplicationWideUsage {
+ provider,
+ model,
+ requests_this_minute,
+ tokens_this_minute,
+ })
+ }
+
+ Ok(results)
+ })
+ .await
+ }
+
pub async fn get_usage(
&self,
user_id: UserId,
@@ -1,6 +1,8 @@
-use anyhow::Result;
+use anyhow::{Context, Result};
use serde::Serialize;
+use crate::clickhouse::write_to_table;
+
#[derive(Serialize, Debug, clickhouse::Row)]
pub struct LlmUsageEventRow {
pub time: i64,
@@ -40,9 +42,10 @@ pub struct LlmRateLimitEventRow {
}
pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> {
- let mut insert = client.insert("llm_usage_events")?;
- insert.write(&row).await?;
- insert.end().await?;
+ const LLM_USAGE_EVENTS_TABLE: &str = "llm_usage_events";
+ write_to_table(LLM_USAGE_EVENTS_TABLE, &[row], client)
+ .await
+ .with_context(|| format!("failed to upload to table '{LLM_USAGE_EVENTS_TABLE}'"))?;
Ok(())
}
@@ -50,8 +53,9 @@ pub async fn report_llm_rate_limit(
client: &clickhouse::Client,
row: LlmRateLimitEventRow,
) -> Result<()> {
- let mut insert = client.insert("llm_rate_limits")?;
- insert.write(&row).await?;
- insert.end().await?;
+ const LLM_RATE_LIMIT_EVENTS_TABLE: &str = "llm_rate_limit_events";
+ write_to_table(LLM_RATE_LIMIT_EVENTS_TABLE, &[row], client)
+ .await
+ .with_context(|| format!("failed to upload to table '{LLM_RATE_LIMIT_EVENTS_TABLE}'"))?;
Ok(())
}
@@ -5,7 +5,7 @@ use axum::{
routing::get,
Extension, Router,
};
-use collab::llm::db::LlmDatabase;
+use collab::llm::{db::LlmDatabase, log_usage_periodically};
use collab::migrations::run_database_migrations;
use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
use collab::{
@@ -95,6 +95,8 @@ async fn main() -> Result<()> {
let state = LlmState::new(config.clone(), Executor::Production).await?;
+ log_usage_periodically(state.clone());
+
app = app
.merge(collab::llm::routes())
.layer(Extension(state.clone()));
@@ -152,7 +154,8 @@ async fn main() -> Result<()> {
matched_path,
user_id = tracing::field::Empty,
login = tracing::field::Empty,
- authn.jti = tracing::field::Empty
+ authn.jti = tracing::field::Empty,
+ is_staff = tracing::field::Empty
)
})
.on_response(
@@ -1,7 +1,6 @@
use crate::db::{self, ChannelRole, NewUserParams};
use anyhow::Context;
-use chrono::{DateTime, Utc};
use db::Database;
use serde::{de::DeserializeOwned, Deserialize};
use std::{fmt::Write, fs, path::Path};
@@ -13,7 +12,6 @@ struct GitHubUser {
id: i32,
login: String,
email: Option<String>,
- created_at: DateTime<Utc>,
}
#[derive(Deserialize)]
@@ -44,6 +42,17 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
let mut first_user = None;
let mut others = vec![];
+ let flag_names = ["remoting", "language-models"];
+ let mut flags = Vec::new();
+
+ for flag_name in flag_names {
+ let flag = db
+ .create_user_flag(flag_name, false)
+ .await
+ .unwrap_or_else(|_| panic!("failed to create flag: '{flag_name}'"));
+ flags.push(flag);
+ }
+
for admin_login in seed_config.admins {
let user = fetch_github::<GitHubUser>(
&client,
@@ -66,6 +75,15 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
} else {
others.push(user.user_id)
}
+
+ for flag in &flags {
+ db.add_user_flag(user.user_id, *flag)
+ .await
+ .context(format!(
+ "Unable to enable flag '{}' for user '{}'",
+ flag, user.user_id
+ ))?;
+ }
}
for channel in seed_config.channels {
@@ -86,6 +104,7 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
}
}
+ // TODO: Fix this later
if let Some(number_of_users) = seed_config.number_of_users {
// Fetch 100 other random users from GitHub and insert them into the database
// (for testing autocompleters, etc.)
@@ -105,15 +124,23 @@ pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result
for github_user in users {
last_user_id = Some(github_user.id);
user_count += 1;
- db.get_or_create_user_by_github_account(
- &github_user.login,
- Some(github_user.id),
- github_user.email.as_deref(),
- Some(github_user.created_at),
- None,
- )
- .await
- .expect("failed to insert user");
+ let user = db
+ .get_or_create_user_by_github_account(
+ &github_user.login,
+ Some(github_user.id),
+ github_user.email.as_deref(),
+ None,
+ None,
+ )
+ .await
+ .expect("failed to insert user");
+
+ for flag in &flags {
+ db.add_user_flag(user.id, *flag).await.context(format!(
+ "Unable to enable flag '{}' for user '{}'",
+ flag, user.id
+ ))?;
+ }
}
}
}
@@ -132,9 +159,9 @@ async fn fetch_github<T: DeserializeOwned>(client: &reqwest::Client, url: &str)
.header("user-agent", "zed")
.send()
.await
- .unwrap_or_else(|_| panic!("failed to fetch '{}'", url));
+ .unwrap_or_else(|error| panic!("failed to fetch '{url}': {error}"));
response
.json()
.await
- .unwrap_or_else(|_| panic!("failed to deserialize github user from '{}'", url))
+ .unwrap_or_else(|error| panic!("failed to deserialize github user from '{url}': {error}"))
}
@@ -30,7 +30,9 @@ fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewC
let model = ContextServerManager::global(&cx);
cx.update_model(&model, |manager, cx| {
for server in manager.servers() {
- manager.restart_server(&server.id, cx).detach();
+ manager
+ .restart_server(&server.id, cx)
+ .detach_and_log_err(cx);
}
});
}
@@ -266,11 +266,11 @@ pub fn init(cx: &mut AppContext) {
log::trace!("servers_to_add={:?}", servers_to_add);
for config in servers_to_add {
- manager.add_server(config, cx).detach();
+ manager.add_server(config, cx).detach_and_log_err(cx);
}
for id in servers_to_remove {
- manager.remove_server(&id, cx).detach();
+ manager.remove_server(&id, cx).detach_and_log_err(cx);
}
})
})
@@ -31,6 +31,8 @@ pub enum Role {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
#[default]
+ #[serde(alias = "gpt-4o", rename = "gpt-4o-2024-05-13")]
+ Gpt4o,
#[serde(alias = "gpt-4", rename = "gpt-4")]
Gpt4,
#[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")]
@@ -40,6 +42,7 @@ pub enum Model {
impl Model {
pub fn from_id(id: &str) -> Result<Self> {
match id {
+ "gpt-4o" => Ok(Self::Gpt4o),
"gpt-4" => Ok(Self::Gpt4),
"gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo),
_ => Err(anyhow!("Invalid model id: {}", id)),
@@ -50,6 +53,7 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "gpt-3.5-turbo",
Self::Gpt4 => "gpt-4",
+ Self::Gpt4o => "gpt-4o",
}
}
@@ -57,11 +61,13 @@ impl Model {
match self {
Self::Gpt3_5Turbo => "GPT-3.5",
Self::Gpt4 => "GPT-4",
+ Self::Gpt4o => "GPT-4o",
}
}
pub fn max_token_count(&self) -> usize {
match self {
+ Self::Gpt4o => 128000,
Self::Gpt4 => 8192,
Self::Gpt3_5Turbo => 16385,
}
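A brief round-trip check of the new variant, based only on the methods shown above:

// "gpt-4o" now parses, displays as GPT-4o, and reports the larger context window.
let model = Model::from_id("gpt-4o").unwrap();
assert_eq!(model.id(), "gpt-4o");
assert_eq!(model.display_name(), "GPT-4o");
assert_eq!(model.max_token_count(), 128000);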
@@ -12,11 +12,11 @@ use ui::{
use util::paths::FILE_ROW_COLUMN_DELIMITER;
use workspace::{item::ItemHandle, StatusItemView, Workspace};
-#[derive(Copy, Clone, Default, PartialOrd, PartialEq)]
-struct SelectionStats {
- lines: usize,
- characters: usize,
- selections: usize,
+#[derive(Copy, Clone, Debug, Default, PartialOrd, PartialEq)]
+pub(crate) struct SelectionStats {
+ pub lines: usize,
+ pub characters: usize,
+ pub selections: usize,
}
pub struct CursorPosition {
@@ -44,7 +44,10 @@ impl CursorPosition {
self.selected_count.selections = editor.selections.count();
let mut last_selection: Option<Selection<usize>> = None;
for selection in editor.selections.all::<usize>(cx) {
- self.selected_count.characters += selection.end - selection.start;
+ self.selected_count.characters += buffer
+ .text_for_range(selection.start..selection.end)
+ .map(|t| t.chars().count())
+ .sum::<usize>();
if last_selection
.as_ref()
.map_or(true, |last_selection| selection.id > last_selection.id)
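The switch from `selection.end - selection.start` (byte offsets) to counting `chars()` over the selected text fixes the reported character count for multibyte input, e.g.:

// Byte length overcounts multibyte characters; the char count does not.
let selected = "ēlo";
assert_eq!(selected.len(), 4);           // bytes ("ē" is two bytes in UTF-8)
assert_eq!(selected.chars().count(), 3); // characters, as now reported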
@@ -106,6 +109,11 @@ impl CursorPosition {
}
text.push(')');
}
+
+ #[cfg(test)]
+ pub(crate) fn selection_stats(&self) -> &SelectionStats {
+ &self.selected_count
+ }
}
impl Render for CursorPosition {
@@ -221,6 +221,8 @@ impl Render for GoToLine {
#[cfg(test)]
mod tests {
use super::*;
+ use cursor_position::{CursorPosition, SelectionStats};
+ use editor::actions::SelectAll;
use gpui::{TestAppContext, VisualTestContext};
use indoc::indoc;
use project::{FakeFs, Project};
@@ -335,6 +337,83 @@ mod tests {
assert_single_caret_at_row(&editor, expected_highlighted_row, cx);
}
+ #[gpui::test]
+ async fn test_unicode_characters_selection(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/dir",
+ json!({
+ "a.rs": "ēlo"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, ["/dir".as_ref()], cx).await;
+ let (workspace, cx) = cx.add_window_view(|cx| Workspace::test_new(project.clone(), cx));
+ workspace.update(cx, |workspace, cx| {
+ let cursor_position = cx.new_view(|_| CursorPosition::new(workspace));
+ workspace.status_bar().update(cx, |status_bar, cx| {
+ status_bar.add_right_item(cursor_position, cx);
+ });
+ });
+
+ let worktree_id = workspace.update(cx, |workspace, cx| {
+ workspace.project().update(cx, |project, cx| {
+ project.worktrees(cx).next().unwrap().read(cx).id()
+ })
+ });
+ let _buffer = project
+ .update(cx, |project, cx| project.open_local_buffer("/dir/a.rs", cx))
+ .await
+ .unwrap();
+ let editor = workspace
+ .update(cx, |workspace, cx| {
+ workspace.open_path((worktree_id, "a.rs"), None, true, cx)
+ })
+ .await
+ .unwrap()
+ .downcast::<Editor>()
+ .unwrap();
+
+ workspace.update(cx, |workspace, cx| {
+ assert_eq!(
+ &SelectionStats {
+ lines: 0,
+ characters: 0,
+ selections: 1,
+ },
+ workspace
+ .status_bar()
+ .read(cx)
+ .item_of_type::<CursorPosition>()
+ .expect("missing cursor position item")
+ .read(cx)
+ .selection_stats(),
+                "There should be no selections initially"
+ );
+ });
+ editor.update(cx, |editor, cx| editor.select_all(&SelectAll, cx));
+ workspace.update(cx, |workspace, cx| {
+ assert_eq!(
+ &SelectionStats {
+ lines: 1,
+ characters: 3,
+ selections: 1,
+ },
+ workspace
+ .status_bar()
+ .read(cx)
+ .item_of_type::<CursorPosition>()
+ .expect("missing cursor position item")
+ .read(cx)
+ .selection_stats(),
+                "After selecting text with multibyte Unicode characters, the character count should be correct"
+ );
+ });
+ }
+
fn open_go_to_line_view(
workspace: &View<Workspace>,
cx: &mut VisualTestContext,
@@ -6,7 +6,7 @@ use std::{
path::{Path, PathBuf},
rc::{Rc, Weak},
sync::{atomic::Ordering::SeqCst, Arc},
- time::Duration,
+ time::{Duration, Instant},
};
use anyhow::{anyhow, Result};
@@ -142,6 +142,12 @@ impl App {
self
}
+ /// Sets a start time for tracking time to first window draw.
+ pub fn measure_time_to_first_window_draw(self, start: Instant) -> Self {
+ self.0.borrow_mut().time_to_first_window_draw = Some(TimeToFirstWindowDraw::Pending(start));
+ self
+ }
+
/// Start the application. The provided callback will be called once the
/// app is fully launched.
pub fn run<F>(self, on_finish_launching: F)
@@ -247,6 +253,7 @@ pub struct AppContext {
pub(crate) layout_id_buffer: Vec<LayoutId>, // We recycle this memory across layout requests.
pub(crate) propagate_event: bool,
pub(crate) prompt_builder: Option<PromptBuilder>,
+ pub(crate) time_to_first_window_draw: Option<TimeToFirstWindowDraw>,
}
impl AppContext {
@@ -300,6 +307,7 @@ impl AppContext {
layout_id_buffer: Default::default(),
propagate_event: true,
prompt_builder: Some(PromptBuilder::Default),
+ time_to_first_window_draw: None,
}),
});
@@ -1302,6 +1310,14 @@ impl AppContext {
(task, is_first)
}
+
+ /// Returns the time to first window draw, if available.
+ pub fn time_to_first_window_draw(&self) -> Option<Duration> {
+ match self.time_to_first_window_draw {
+ Some(TimeToFirstWindowDraw::Done(duration)) => Some(duration),
+ _ => None,
+ }
+ }
}
impl Context for AppContext {
@@ -1465,6 +1481,15 @@ impl<G: Global> DerefMut for GlobalLease<G> {
}
}
+/// Represents the initialization duration of the application.
+#[derive(Clone, Copy)]
+pub enum TimeToFirstWindowDraw {
+ /// The application is still initializing, and contains the start time.
+ Pending(Instant),
+ /// The application has finished initializing, and contains the total duration.
+ Done(Duration),
+}
+
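A hedged sketch of how an application can wire this up at startup; `App::new()` and `run` are the usual GPUI entry points, and the helper below is illustrative:

// Record a start time as early as possible and hand it to GPUI before running.
let start = std::time::Instant::now();
gpui::App::new()
    .measure_time_to_first_window_draw(start)
    .run(move |_cx| {
        // Open windows, register views, etc.
    });

// Elsewhere, once the first window has drawn, any code with an AppContext
// can read the measured duration (it stays None while still Pending):
fn log_first_draw(cx: &gpui::AppContext) {
    if let Some(duration) = cx.time_to_first_window_draw() {
        log::info!("time to first window draw: {duration:?}");
    }
}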
/// Contains state associated with an active drag operation, started by dragging an element
/// within the window or by dragging into the app from the underlying platform.
pub struct AnyDrag {
@@ -16,6 +16,7 @@ mod blade;
#[cfg(any(test, feature = "test-support"))]
mod test;
+mod fps;
#[cfg(target_os = "windows")]
mod windows;
@@ -51,6 +52,7 @@ use strum::EnumIter;
use uuid::Uuid;
pub use app_menu::*;
+pub use fps::*;
pub use keystroke::*;
#[cfg(target_os = "linux")]
@@ -354,7 +356,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
fn on_should_close(&self, callback: Box<dyn FnMut() -> bool>);
fn on_close(&self, callback: Box<dyn FnOnce()>);
fn on_appearance_changed(&self, callback: Box<dyn FnMut()>);
- fn draw(&self, scene: &Scene);
+ fn draw(&self, scene: &Scene, on_complete: Option<oneshot::Sender<()>>);
fn completed_frame(&self) {}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas>;
@@ -379,6 +381,7 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
}
fn set_client_inset(&self, _inset: Pixels) {}
fn gpu_specs(&self) -> Option<GPUSpecs>;
+ fn fps(&self) -> Option<f32>;
#[cfg(any(test, feature = "test-support"))]
fn as_test(&mut self) -> Option<&mut TestWindow> {
@@ -9,6 +9,7 @@ use crate::{
};
use bytemuck::{Pod, Zeroable};
use collections::HashMap;
+use futures::channel::oneshot;
#[cfg(target_os = "macos")]
use media::core_video::CVMetalTextureCache;
#[cfg(target_os = "macos")]
@@ -537,7 +538,12 @@ impl BladeRenderer {
self.gpu.destroy_command_encoder(&mut self.command_encoder);
}
- pub fn draw(&mut self, scene: &Scene) {
+ pub fn draw(
+ &mut self,
+ scene: &Scene,
+ // Required to compile on macOS, but not currently supported.
+ _on_complete: Option<oneshot::Sender<()>>,
+ ) {
self.command_encoder.start();
self.atlas.before_frame(&mut self.command_encoder);
self.rasterize_paths(scene.paths());
@@ -766,4 +772,9 @@ impl BladeRenderer {
self.wait_for_gpu();
self.last_sync_point = Some(sync_point);
}
+
+ /// Required to compile on macOS, but not currently supported.
+ pub fn fps(&self) -> f32 {
+ 0.0
+ }
}
@@ -0,0 +1,94 @@
+use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
+use std::sync::Arc;
+
+const NANOS_PER_SEC: u64 = 1_000_000_000;
+const WINDOW_SIZE: usize = 128;
+
+/// Represents a rolling FPS (Frames Per Second) counter.
+///
+/// This struct provides a lock-free mechanism to measure and calculate FPS
+/// continuously, updating with every frame. It uses atomic operations to
+/// ensure thread-safety without the need for locks.
+pub struct FpsCounter {
+ frame_times: [AtomicU64; WINDOW_SIZE],
+ head: AtomicUsize,
+ tail: AtomicUsize,
+}
+
+impl FpsCounter {
+    /// Creates a new `FpsCounter`.
+    ///
+    /// Returns an `Arc<FpsCounter>` for safe sharing across threads.
+ pub fn new() -> Arc<Self> {
+ Arc::new(Self {
+ frame_times: std::array::from_fn(|_| AtomicU64::new(0)),
+ head: AtomicUsize::new(0),
+ tail: AtomicUsize::new(0),
+ })
+ }
+
+ /// Increments the FPS counter with a new frame timestamp.
+ ///
+ /// This method updates the internal state to maintain a rolling window
+ /// of frame data for the last second. It uses atomic operations to
+ /// ensure thread-safety.
+ ///
+ /// # Arguments
+ ///
+ /// * `timestamp_ns` - The timestamp of the new frame in nanoseconds.
+ pub fn increment(&self, timestamp_ns: u64) {
+ let mut head = self.head.load(Ordering::Relaxed);
+ let mut tail = self.tail.load(Ordering::Relaxed);
+
+ // Add new timestamp
+ self.frame_times[head].store(timestamp_ns, Ordering::Relaxed);
+ // Increment head and wrap around to 0 if it reaches WINDOW_SIZE
+ head = (head + 1) % WINDOW_SIZE;
+ self.head.store(head, Ordering::Relaxed);
+
+ // Remove old timestamps (older than 1 second)
+ while tail != head {
+ let oldest = self.frame_times[tail].load(Ordering::Relaxed);
+ if timestamp_ns.wrapping_sub(oldest) <= NANOS_PER_SEC {
+ break;
+ }
+ // Increment tail and wrap around to 0 if it reaches WINDOW_SIZE
+ tail = (tail + 1) % WINDOW_SIZE;
+ self.tail.store(tail, Ordering::Relaxed);
+ }
+ }
+
+ /// Calculates and returns the current FPS.
+ ///
+ /// This method computes the FPS based on the frames recorded in the last second.
+ /// It uses atomic loads to ensure thread-safety.
+ ///
+ /// # Returns
+ ///
+ /// The calculated FPS as a `f32`, or 0.0 if no frames have been recorded.
+ pub fn fps(&self) -> f32 {
+ let head = self.head.load(Ordering::Relaxed);
+ let tail = self.tail.load(Ordering::Relaxed);
+
+ if head == tail {
+ return 0.0;
+ }
+
+ let newest =
+ self.frame_times[head.wrapping_sub(1) & (WINDOW_SIZE - 1)].load(Ordering::Relaxed);
+ let oldest = self.frame_times[tail].load(Ordering::Relaxed);
+
+ let time_diff = newest.wrapping_sub(oldest) as f32;
+ if time_diff == 0.0 {
+ return 0.0;
+ }
+
+ let frame_count = if head > tail {
+ head - tail
+ } else {
+ WINDOW_SIZE - tail + head
+ };
+
+ (frame_count as f32 - 1.0) * NANOS_PER_SEC as f32 / time_diff
+ }
+}
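A small usage sketch of the counter in isolation, with synthetic timestamps of roughly one frame every 16.6 ms:

let counter = FpsCounter::new();
// Simulate two seconds of ~60 FPS frame completions.
for frame in 0..120u64 {
    counter.increment(frame * 16_666_667);
}
// Only frames within the trailing one-second window contribute to the estimate.
println!("rolling fps: {:.1}", counter.fps()); // prints roughly 60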
@@ -6,7 +6,7 @@ use std::sync::Arc;
use blade_graphics as gpu;
use collections::HashMap;
-use futures::channel::oneshot::Receiver;
+use futures::channel::oneshot;
use raw_window_handle as rwh;
use wayland_backend::client::ObjectId;
@@ -827,7 +827,7 @@ impl PlatformWindow for WaylandWindow {
_msg: &str,
_detail: Option<&str>,
_answers: &[&str],
- ) -> Option<Receiver<usize>> {
+ ) -> Option<oneshot::Receiver<usize>> {
None
}
@@ -934,9 +934,9 @@ impl PlatformWindow for WaylandWindow {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
+ fn draw(&self, scene: &Scene, on_complete: Option<oneshot::Sender<()>>) {
let mut state = self.borrow_mut();
- state.renderer.draw(scene);
+ state.renderer.draw(scene, on_complete);
}
fn completed_frame(&self) {
@@ -1009,6 +1009,10 @@ impl PlatformWindow for WaylandWindow {
fn gpu_specs(&self) -> Option<GPUSpecs> {
self.borrow().renderer.gpu_specs().into()
}
+
+ fn fps(&self) -> Option<f32> {
+ None
+ }
}
fn update_window(mut state: RefMut<WaylandWindowState>) {
@@ -1,5 +1,3 @@
-use anyhow::Context;
-
use crate::{
platform::blade::{BladeRenderer, BladeSurfaceConfig},
px, size, AnyWindowHandle, Bounds, Decorations, DevicePixels, ForegroundExecutor, GPUSpecs,
@@ -9,7 +7,9 @@ use crate::{
X11ClientStatePtr,
};
+use anyhow::Context;
use blade_graphics as gpu;
+use futures::channel::oneshot;
use raw_window_handle as rwh;
use util::{maybe, ResultExt};
use x11rb::{
@@ -1210,9 +1210,10 @@ impl PlatformWindow for X11Window {
self.0.callbacks.borrow_mut().appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
+ // TODO: on_complete not yet supported for X11 windows
+ fn draw(&self, scene: &Scene, on_complete: Option<oneshot::Sender<()>>) {
let mut inner = self.0.state.borrow_mut();
- inner.renderer.draw(scene);
+ inner.renderer.draw(scene, on_complete);
}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> {
@@ -1398,4 +1399,8 @@ impl PlatformWindow for X11Window {
fn gpu_specs(&self) -> Option<GPUSpecs> {
self.0.state.borrow().renderer.gpu_specs().into()
}
+
+ fn fps(&self) -> Option<f32> {
+ None
+ }
}
@@ -1,7 +1,7 @@
use super::metal_atlas::MetalAtlas;
use crate::{
point, size, AtlasTextureId, AtlasTextureKind, AtlasTile, Bounds, ContentMask, DevicePixels,
- Hsla, MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite,
+ FpsCounter, Hsla, MonochromeSprite, PaintSurface, Path, PathId, PathVertex, PolychromeSprite,
PrimitiveBatch, Quad, ScaledPixels, Scene, Shadow, Size, Surface, Underline,
};
use anyhow::{anyhow, Result};
@@ -14,6 +14,7 @@ use cocoa::{
use collections::HashMap;
use core_foundation::base::TCFType;
use foreign_types::ForeignType;
+use futures::channel::oneshot;
use media::core_video::CVMetalTextureCache;
use metal::{CAMetalLayer, CommandQueue, MTLPixelFormat, MTLResourceOptions, NSRange};
use objc::{self, msg_send, sel, sel_impl};
@@ -105,6 +106,7 @@ pub(crate) struct MetalRenderer {
instance_buffer_pool: Arc<Mutex<InstanceBufferPool>>,
sprite_atlas: Arc<MetalAtlas>,
core_video_texture_cache: CVMetalTextureCache,
+ fps_counter: Arc<FpsCounter>,
}
impl MetalRenderer {
@@ -250,6 +252,7 @@ impl MetalRenderer {
instance_buffer_pool,
sprite_atlas,
core_video_texture_cache,
+ fps_counter: FpsCounter::new(),
}
}
@@ -292,7 +295,8 @@ impl MetalRenderer {
// nothing to do
}
- pub fn draw(&mut self, scene: &Scene) {
+ pub fn draw(&mut self, scene: &Scene, on_complete: Option<oneshot::Sender<()>>) {
+ let on_complete = Arc::new(Mutex::new(on_complete));
let layer = self.layer.clone();
let viewport_size = layer.drawable_size();
let viewport_size: Size<DevicePixels> = size(
@@ -319,13 +323,24 @@ impl MetalRenderer {
Ok(command_buffer) => {
let instance_buffer_pool = self.instance_buffer_pool.clone();
let instance_buffer = Cell::new(Some(instance_buffer));
- let block = ConcreteBlock::new(move |_| {
- if let Some(instance_buffer) = instance_buffer.take() {
- instance_buffer_pool.lock().release(instance_buffer);
- }
- });
- let block = block.copy();
- command_buffer.add_completed_handler(&block);
+ let device = self.device.clone();
+ let fps_counter = self.fps_counter.clone();
+ let completed_handler =
+ ConcreteBlock::new(move |_: &metal::CommandBufferRef| {
+ let mut cpu_timestamp = 0;
+ let mut gpu_timestamp = 0;
+ device.sample_timestamps(&mut cpu_timestamp, &mut gpu_timestamp);
+
+ fps_counter.increment(gpu_timestamp);
+ if let Some(on_complete) = on_complete.lock().take() {
+ on_complete.send(()).ok();
+ }
+ if let Some(instance_buffer) = instance_buffer.take() {
+ instance_buffer_pool.lock().release(instance_buffer);
+ }
+ });
+ let completed_handler = completed_handler.copy();
+ command_buffer.add_completed_handler(&completed_handler);
if self.presents_with_transaction {
command_buffer.commit();
@@ -1117,6 +1132,10 @@ impl MetalRenderer {
}
true
}
+
+ pub fn fps(&self) -> f32 {
+ self.fps_counter.fps()
+ }
}
fn build_pipeline_state(
@@ -784,14 +784,14 @@ impl PlatformWindow for MacWindow {
self.0.as_ref().lock().bounds()
}
- fn window_bounds(&self) -> WindowBounds {
- self.0.as_ref().lock().window_bounds()
- }
-
fn is_maximized(&self) -> bool {
self.0.as_ref().lock().is_maximized()
}
+ fn window_bounds(&self) -> WindowBounds {
+ self.0.as_ref().lock().window_bounds()
+ }
+
fn content_size(&self) -> Size<Pixels> {
self.0.as_ref().lock().content_size()
}
@@ -975,8 +975,6 @@ impl PlatformWindow for MacWindow {
}
}
- fn set_app_id(&mut self, _app_id: &str) {}
-
fn set_background_appearance(&self, background_appearance: WindowBackgroundAppearance) {
let mut this = self.0.as_ref().lock();
this.renderer
@@ -1007,30 +1005,6 @@ impl PlatformWindow for MacWindow {
}
}
- fn set_edited(&mut self, edited: bool) {
- unsafe {
- let window = self.0.lock().native_window;
- msg_send![window, setDocumentEdited: edited as BOOL]
- }
-
- // Changing the document edited state resets the traffic light position,
- // so we have to move it again.
- self.0.lock().move_traffic_light();
- }
-
- fn show_character_palette(&self) {
- let this = self.0.lock();
- let window = this.native_window;
- this.executor
- .spawn(async move {
- unsafe {
- let app = NSApplication::sharedApplication(nil);
- let _: () = msg_send![app, orderFrontCharacterPalette: window];
- }
- })
- .detach();
- }
-
fn minimize(&self) {
let window = self.0.lock().native_window;
unsafe {
@@ -1107,18 +1081,48 @@ impl PlatformWindow for MacWindow {
self.0.lock().appearance_changed_callback = Some(callback);
}
- fn draw(&self, scene: &crate::Scene) {
+ fn draw(&self, scene: &crate::Scene, on_complete: Option<oneshot::Sender<()>>) {
let mut this = self.0.lock();
- this.renderer.draw(scene);
+ this.renderer.draw(scene, on_complete);
}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> {
self.0.lock().renderer.sprite_atlas().clone()
}
+ fn set_edited(&mut self, edited: bool) {
+ unsafe {
+ let window = self.0.lock().native_window;
+ msg_send![window, setDocumentEdited: edited as BOOL]
+ }
+
+ // Changing the document edited state resets the traffic light position,
+ // so we have to move it again.
+ self.0.lock().move_traffic_light();
+ }
+
+ fn show_character_palette(&self) {
+ let this = self.0.lock();
+ let window = this.native_window;
+ this.executor
+ .spawn(async move {
+ unsafe {
+ let app = NSApplication::sharedApplication(nil);
+ let _: () = msg_send![app, orderFrontCharacterPalette: window];
+ }
+ })
+ .detach();
+ }
+
+ fn set_app_id(&mut self, _app_id: &str) {}
+
fn gpu_specs(&self) -> Option<crate::GPUSpecs> {
None
}
+
+ fn fps(&self) -> Option<f32> {
+ Some(self.0.lock().renderer.fps())
+ }
}
impl rwh::HasWindowHandle for MacWindow {
@@ -251,7 +251,12 @@ impl PlatformWindow for TestWindow {
fn on_appearance_changed(&self, _callback: Box<dyn FnMut()>) {}
- fn draw(&self, _scene: &crate::Scene) {}
+ fn draw(
+ &self,
+ _scene: &crate::Scene,
+ _on_complete: Option<futures::channel::oneshot::Sender<()>>,
+ ) {
+ }
fn sprite_atlas(&self) -> sync::Arc<dyn crate::PlatformAtlas> {
self.0.lock().sprite_atlas.clone()
@@ -277,6 +282,10 @@ impl PlatformWindow for TestWindow {
fn gpu_specs(&self) -> Option<GPUSpecs> {
None
}
+
+ fn fps(&self) -> Option<f32> {
+ None
+ }
}
pub(crate) struct TestAtlasState {
@@ -660,8 +660,8 @@ impl PlatformWindow for WindowsWindow {
self.0.state.borrow_mut().callbacks.appearance_changed = Some(callback);
}
- fn draw(&self, scene: &Scene) {
- self.0.state.borrow_mut().renderer.draw(scene)
+ fn draw(&self, scene: &Scene, on_complete: Option<oneshot::Sender<()>>) {
+ self.0.state.borrow_mut().renderer.draw(scene, on_complete)
}
fn sprite_atlas(&self) -> Arc<dyn PlatformAtlas> {
@@ -675,6 +675,10 @@ impl PlatformWindow for WindowsWindow {
fn gpu_specs(&self) -> Option<GPUSpecs> {
Some(self.0.state.borrow().renderer.gpu_specs())
}
+
+ fn fps(&self) -> Option<f32> {
+ None
+ }
}
#[implement(IDropTarget)]
@@ -11,9 +11,9 @@ use crate::{
PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams,
Replay, ResizeEdge, ScaledPixels, Scene, Shadow, SharedString, Size, StrikethroughStyle, Style,
SubscriberSet, Subscription, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement,
- TransformationMatrix, Underline, UnderlineStyle, View, VisualContext, WeakView,
- WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations,
- WindowOptions, WindowParams, WindowTextSystem, SUBPIXEL_VARIANTS,
+ TimeToFirstWindowDraw, TransformationMatrix, Underline, UnderlineStyle, View, VisualContext,
+ WeakView, WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls,
+ WindowDecorations, WindowOptions, WindowParams, WindowTextSystem, SUBPIXEL_VARIANTS,
};
use anyhow::{anyhow, Context as _, Result};
use collections::{FxHashMap, FxHashSet};
@@ -544,6 +544,8 @@ pub struct Window {
hovered: Rc<Cell<bool>>,
pub(crate) dirty: Rc<Cell<bool>>,
pub(crate) needs_present: Rc<Cell<bool>>,
+ /// We assign this to be notified when the platform graphics backend fires the next completion callback for drawing the window.
+ present_completed: RefCell<Option<oneshot::Sender<()>>>,
pub(crate) last_input_timestamp: Rc<Cell<Instant>>,
pub(crate) refreshing: bool,
pub(crate) draw_phase: DrawPhase,
@@ -820,6 +822,7 @@ impl Window {
hovered,
dirty,
needs_present,
+ present_completed: RefCell::default(),
last_input_timestamp,
refreshing: false,
draw_phase: DrawPhase::None,
@@ -1489,13 +1492,29 @@ impl<'a> WindowContext<'a> {
self.window.refreshing = false;
self.window.draw_phase = DrawPhase::None;
self.window.needs_present.set(true);
+
+ if let Some(TimeToFirstWindowDraw::Pending(start)) = self.app.time_to_first_window_draw {
+ let (tx, rx) = oneshot::channel();
+ *self.window.present_completed.borrow_mut() = Some(tx);
+ self.spawn(|mut cx| async move {
+ rx.await.ok();
+ cx.update(|cx| {
+ let duration = start.elapsed();
+ cx.time_to_first_window_draw = Some(TimeToFirstWindowDraw::Done(duration));
+ log::info!("time to first window draw: {:?}", duration);
+ cx.push_effect(Effect::Refresh);
+ })
+ })
+ .detach();
+ }
}
#[profiling::function]
fn present(&self) {
+ let on_complete = self.window.present_completed.take();
self.window
.platform_window
- .draw(&self.window.rendered_frame.scene);
+ .draw(&self.window.rendered_frame.scene, on_complete);
self.window.needs_present.set(false);
profiling::finish_frame!();
}
@@ -3718,6 +3737,12 @@ impl<'a> WindowContext<'a> {
pub fn gpu_specs(&self) -> Option<GPUSpecs> {
self.window.platform_window.gpu_specs()
}
+
+ /// Get the current FPS (frames per second) of the window.
+ /// This is only supported on macOS currently.
+ pub fn fps(&self) -> Option<f32> {
+ self.window.platform_window.fps()
+ }
}
#[cfg(target_os = "windows")]
@@ -54,6 +54,10 @@ pub struct LanguageModelCacheConfiguration {
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
+ /// If None, falls back to [LanguageModelProvider::icon]
+ fn icon(&self) -> Option<IconName> {
+ None
+ }
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
@@ -2,6 +2,7 @@ use proto::Plan;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use strum::EnumIter;
+use ui::IconName;
use crate::LanguageModelAvailability;
@@ -65,6 +66,13 @@ impl CloudModel {
}
}
+ pub fn icon(&self) -> Option<IconName> {
+ match self {
+ Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
+ _ => None,
+ }
+ }
+
pub fn max_token_count(&self) -> usize {
match self {
Self::Anthropic(model) => model.max_token_count(),
@@ -19,7 +19,7 @@ use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
-use ui::{prelude::*, Icon, IconName};
+use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
const PROVIDER_ID: &str = "anthropic";
@@ -29,15 +29,22 @@ const PROVIDER_NAME: &str = "Anthropic";
pub struct AnthropicSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
+ /// Extend Zed's list of Anthropic models.
pub available_models: Vec<AvailableModel>,
pub needs_setting_migration: bool,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
+ /// The model's name in the Anthropic API. e.g. claude-3-5-sonnet-20240620
pub name: String,
+ /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
+ pub display_name: Option<String>,
+ /// The model's context window size.
pub max_tokens: usize,
+ /// A model `name` to substitute when calling tools, in case the primary model doesn't support tool calling.
pub tool_override: Option<String>,
+ /// Configuration of Anthropic's caching API.
pub cache_configuration: Option<LanguageModelCacheConfiguration>,
pub max_output_tokens: Option<u32>,
}
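For reference, a hypothetical custom entry exercising the documented fields (values are illustrative; in practice this is deserialized from the `available_models` setting):

let custom = AvailableModel {
    name: "claude-3-5-sonnet-20240620".into(),
    display_name: Some("Claude 3.5 Sonnet (pinned)".into()),
    max_tokens: 200_000,
    tool_override: None,
    cache_configuration: None,
    max_output_tokens: Some(4_096),
};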
@@ -47,8 +54,11 @@ pub struct AnthropicLanguageModelProvider {
state: gpui::Model<State>,
}
+const ANTHROPIC_API_KEY_VAR: &'static str = "ANTHROPIC_API_KEY";
+
pub struct State {
api_key: Option<String>,
+ api_key_from_env: bool,
_subscription: Subscription,
}
@@ -60,6 +70,7 @@ impl State {
delete_credentials.await.ok();
this.update(&mut cx, |this, cx| {
this.api_key = None;
+ this.api_key_from_env = false;
cx.notify();
})
})
@@ -98,18 +109,20 @@ impl State {
.clone();
cx.spawn(|this, mut cx| async move {
- let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
- api_key
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(ANTHROPIC_API_KEY_VAR)
+ {
+ (api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
- String::from_utf8(api_key)?
+ (String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
cx.notify();
})
})
@@ -121,6 +134,7 @@ impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
+ api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@@ -171,6 +185,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
model.name.clone(),
anthropic::Model::Custom {
name: model.name.clone(),
+ display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
cache_configuration: model.cache_configuration.as_ref().map(|config| {
@@ -529,6 +544,8 @@ impl Render for ConfigurationView {
"Paste your Anthropic API key below and hit enter to use the assistant:",
];
+ let env_var_set = self.state.read(cx).api_key_from_env;
+
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
@@ -550,7 +567,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- "You can also assign the ANTHROPIC_API_KEY environment variable and restart Zed.",
+                                format!("You can also assign the {ANTHROPIC_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small),
)
@@ -563,13 +580,21 @@ impl Render for ConfigurationView {
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
- .child(Label::new("API key configured.")),
+ .child(Label::new(if env_var_set {
+ format!("API key set in {ANTHROPIC_API_KEY_VAR} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
)
.child(
Button::new("reset-key", "Reset key")
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {ANTHROPIC_API_KEY_VAR} environment variable."), cx))
+ })
.on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
)
.into_any()
@@ -52,12 +52,20 @@ pub enum AvailableProvider {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
- provider: AvailableProvider,
- name: String,
- max_tokens: usize,
- tool_override: Option<String>,
- cache_configuration: Option<LanguageModelCacheConfiguration>,
- max_output_tokens: Option<u32>,
+ /// The provider of the language model.
+ pub provider: AvailableProvider,
+ /// The model's name in the provider's API. e.g. claude-3-5-sonnet-20240620
+ pub name: String,
+ /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
+ pub display_name: Option<String>,
+ /// The size of the context window, indicating the maximum number of tokens the model can process.
+ pub max_tokens: usize,
+ /// The maximum number of output tokens allowed by the model.
+ pub max_output_tokens: Option<u32>,
+ /// Override this model with a different Anthropic model for tool calls.
+ pub tool_override: Option<String>,
+ /// Indicates whether this custom model supports caching.
+ pub cache_configuration: Option<LanguageModelCacheConfiguration>,
}
pub struct CloudLanguageModelProvider {
@@ -202,6 +210,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
AvailableProvider::Anthropic => {
CloudModel::Anthropic(anthropic::Model::Custom {
name: model.name.clone(),
+ display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
cache_configuration: model.cache_configuration.as_ref().map(|config| {
@@ -389,6 +398,10 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
+ fn icon(&self) -> Option<IconName> {
+ self.model.icon()
+ }
+
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
@@ -180,6 +180,7 @@ impl LanguageModel for CopilotChatLanguageModel {
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
let model = match self.model {
+ CopilotChatModel::Gpt4o => open_ai::Model::FourOmni,
CopilotChatModel::Gpt4 => open_ai::Model::Four,
CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo,
};
@@ -14,7 +14,7 @@ use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
-use ui::{prelude::*, Icon, IconName};
+use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
use crate::{
@@ -46,9 +46,12 @@ pub struct GoogleLanguageModelProvider {
pub struct State {
api_key: Option<String>,
+ api_key_from_env: bool,
_subscription: Subscription,
}
+const GOOGLE_AI_API_KEY_VAR: &'static str = "GOOGLE_AI_API_KEY";
+
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
@@ -61,6 +64,7 @@ impl State {
delete_credentials.await.ok();
this.update(&mut cx, |this, cx| {
this.api_key = None;
+ this.api_key_from_env = false;
cx.notify();
})
})
@@ -90,18 +94,20 @@ impl State {
.clone();
cx.spawn(|this, mut cx| async move {
- let api_key = if let Ok(api_key) = std::env::var("GOOGLE_AI_API_KEY") {
- api_key
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(GOOGLE_AI_API_KEY_VAR)
+ {
+ (api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
- String::from_utf8(api_key)?
+ (String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
cx.notify();
})
})
@@ -113,6 +119,7 @@ impl GoogleLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
+ api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@@ -422,6 +429,8 @@ impl Render for ConfigurationView {
"Paste your Google AI API key below and hit enter to use the assistant:",
];
+ let env_var_set = self.state.read(cx).api_key_from_env;
+
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
@@ -443,7 +452,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- "You can also assign the GOOGLE_AI_API_KEY environment variable and restart Zed.",
+ format!("You can also assign the {GOOGLE_AI_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small),
)
@@ -456,13 +465,21 @@ impl Render for ConfigurationView {
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
- .child(Label::new("API key configured.")),
+ .child(Label::new(if env_var_set {
+ format!("API key set in {GOOGLE_AI_API_KEY_VAR} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
)
.child(
Button::new("reset-key", "Reset key")
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {GOOGLE_AI_API_KEY_VAR} environment variable."), cx))
+ })
.on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
)
.into_any()
@@ -16,7 +16,7 @@ use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
-use ui::{prelude::*, Icon, IconName};
+use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;
use crate::{
@@ -49,9 +49,12 @@ pub struct OpenAiLanguageModelProvider {
pub struct State {
api_key: Option<String>,
+ api_key_from_env: bool,
_subscription: Subscription,
}
+const OPENAI_API_KEY_VAR: &'static str = "OPENAI_API_KEY";
+
impl State {
fn is_authenticated(&self) -> bool {
self.api_key.is_some()
@@ -64,6 +67,7 @@ impl State {
delete_credentials.await.log_err();
this.update(&mut cx, |this, cx| {
this.api_key = None;
+ this.api_key_from_env = false;
cx.notify();
})
})
@@ -92,17 +96,18 @@ impl State {
.api_url
.clone();
cx.spawn(|this, mut cx| async move {
- let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
- api_key
+ let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
+ (api_key, true)
} else {
let (_, api_key) = cx
.update(|cx| cx.read_credentials(&api_url))?
.await?
.ok_or_else(|| anyhow!("credentials not found"))?;
- String::from_utf8(api_key)?
+ (String::from_utf8(api_key)?, false)
};
this.update(&mut cx, |this, cx| {
this.api_key = Some(api_key);
+ this.api_key_from_env = from_env;
cx.notify();
})
})
@@ -114,6 +119,7 @@ impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
+ api_key_from_env: false,
_subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
@@ -476,6 +482,8 @@ impl Render for ConfigurationView {
"Paste your OpenAI API key below and hit enter to use the assistant:",
];
+ let env_var_set = self.state.read(cx).api_key_from_env;
+
if self.load_credentials_task.is_some() {
div().child(Label::new("Loading credentials...")).into_any()
} else if self.should_render_editor(cx) {
@@ -497,7 +505,7 @@ impl Render for ConfigurationView {
)
.child(
Label::new(
- "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
+ format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
)
.size(LabelSize::Small),
)
@@ -510,13 +518,21 @@ impl Render for ConfigurationView {
h_flex()
.gap_1()
.child(Icon::new(IconName::Check).color(Color::Success))
- .child(Label::new("API key configured.")),
+ .child(Label::new(if env_var_set {
+ format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
+ } else {
+ "API key configured.".to_string()
+ })),
)
.child(
Button::new("reset-key", "Reset key")
.icon(Some(IconName::Trash))
.icon_size(IconSize::Small)
.icon_position(IconPosition::Start)
+ .disabled(env_var_set)
+ .when(env_var_set, |this| {
+ this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."), cx))
+ })
.on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
)
.into_any()
@@ -94,12 +94,14 @@ impl AnthropicSettingsContent {
.filter_map(|model| match model {
anthropic::Model::Custom {
name,
+ display_name,
max_tokens,
tool_override,
cache_configuration,
max_output_tokens,
} => Some(provider::anthropic::AvailableModel {
name,
+ display_name,
max_tokens,
tool_override,
cache_configuration: cache_configuration.as_ref().map(
@@ -193,15 +193,28 @@ pub fn prompts_dir() -> &'static PathBuf {
/// Returns the path to the prompt templates directory.
///
/// This is where the prompt templates for core features can be overridden with templates.
-pub fn prompt_overrides_dir() -> &'static PathBuf {
- static PROMPT_TEMPLATES_DIR: OnceLock<PathBuf> = OnceLock::new();
- PROMPT_TEMPLATES_DIR.get_or_init(|| {
- if cfg!(target_os = "macos") {
- config_dir().join("prompt_overrides")
- } else {
- support_dir().join("prompt_overrides")
+///
+/// # Arguments
+///
+/// * `repo_path` - If provided and the path contains an `assets/prompts` directory (i.e. a checked-out Zed repository), that directory is returned instead of the user override directory.
+pub fn prompt_overrides_dir(repo_path: Option<&Path>) -> PathBuf {
+ if let Some(path) = repo_path {
+ let dev_path = path.join("assets").join("prompts");
+ if dev_path.exists() {
+ return dev_path;
}
- })
+ }
+
+ static PROMPT_TEMPLATES_DIR: OnceLock<PathBuf> = OnceLock::new();
+ PROMPT_TEMPLATES_DIR
+ .get_or_init(|| {
+ if cfg!(target_os = "macos") {
+ config_dir().join("prompt_overrides")
+ } else {
+ support_dir().join("prompt_overrides")
+ }
+ })
+ .clone()
}
/// Returns the path to the semantic search's embeddings directory.
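A hedged sketch of the two call patterns the new signature supports (the repository path is a placeholder):

use std::path::Path;

// Development: prefer `<repo>/assets/prompts` when it exists.
let dev_dir = prompt_overrides_dir(Some(Path::new("/path/to/zed")));

// Regular installs: fall back to the per-platform override directory
// (config_dir()/prompt_overrides on macOS, support_dir()/prompt_overrides elsewhere).
let user_dir = prompt_overrides_dir(None);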
@@ -0,0 +1,36 @@
+[package]
+name = "performance"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/performance.rs"
+doctest = false
+
+[features]
+test-support = [
+ "collections/test-support",
+ "gpui/test-support",
+ "workspace/test-support",
+]
+
+[dependencies]
+anyhow.workspace = true
+gpui.workspace = true
+log.workspace = true
+schemars.workspace = true
+serde.workspace = true
+settings.workspace = true
+workspace.workspace = true
+
+[dev-dependencies]
+collections = { workspace = true, features = ["test-support"] }
+gpui = { workspace = true, features = ["test-support"] }
+settings = { workspace = true, features = ["test-support"] }
+util = { workspace = true, features = ["test-support"] }
+workspace = { workspace = true, features = ["test-support"] }
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,189 @@
+use std::time::Instant;
+
+use anyhow::Result;
+use gpui::{
+ div, AppContext, InteractiveElement as _, Render, StatefulInteractiveElement as _,
+ Subscription, ViewContext, VisualContext,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsSources, SettingsStore};
+use workspace::{
+ ui::{Label, LabelCommon, LabelSize, Tooltip},
+ ItemHandle, StatusItemView, Workspace,
+};
+
+const SHOW_STARTUP_TIME_DURATION: std::time::Duration = std::time::Duration::from_secs(5);
+
+pub fn init(cx: &mut AppContext) {
+ PerformanceSettings::register(cx);
+
+ let mut enabled = PerformanceSettings::get_global(cx)
+ .show_in_status_bar
+ .unwrap_or(false);
+ let start_time = Instant::now();
+ let mut _observe_workspaces = toggle_status_bar_items(enabled, start_time, cx);
+
+ cx.observe_global::<SettingsStore>(move |cx| {
+ let new_value = PerformanceSettings::get_global(cx)
+ .show_in_status_bar
+ .unwrap_or(false);
+ if new_value != enabled {
+ enabled = new_value;
+ _observe_workspaces = toggle_status_bar_items(enabled, start_time, cx);
+ }
+ })
+ .detach();
+}
+
+fn toggle_status_bar_items(
+ enabled: bool,
+ start_time: Instant,
+ cx: &mut AppContext,
+) -> Option<Subscription> {
+ for window in cx.windows() {
+ if let Some(workspace) = window.downcast::<Workspace>() {
+ workspace
+ .update(cx, |workspace, cx| {
+ toggle_status_bar_item(workspace, enabled, start_time, cx);
+ })
+ .ok();
+ }
+ }
+
+ if enabled {
+ log::info!("performance metrics display enabled");
+ Some(cx.observe_new_views::<Workspace>(move |workspace, cx| {
+ toggle_status_bar_item(workspace, true, start_time, cx);
+ }))
+ } else {
+ log::info!("performance metrics display disabled");
+ None
+ }
+}
+
+struct PerformanceStatusBarItem {
+ display_mode: DisplayMode,
+}
+
+#[derive(Copy, Clone, Debug)]
+enum DisplayMode {
+ StartupTime,
+ Fps,
+}
+
+impl PerformanceStatusBarItem {
+ fn new(start_time: Instant, cx: &mut ViewContext<Self>) -> Self {
+ let now = Instant::now();
+ let display_mode = if now < start_time + SHOW_STARTUP_TIME_DURATION {
+ DisplayMode::StartupTime
+ } else {
+ DisplayMode::Fps
+ };
+
+ let this = Self { display_mode };
+
+ if let DisplayMode::StartupTime = display_mode {
+ cx.spawn(|this, mut cx| async move {
+ let now = Instant::now();
+ let remaining_duration =
+ (start_time + SHOW_STARTUP_TIME_DURATION).saturating_duration_since(now);
+ cx.background_executor().timer(remaining_duration).await;
+ this.update(&mut cx, |this, cx| {
+ this.display_mode = DisplayMode::Fps;
+ cx.notify();
+ })
+ .ok();
+ })
+ .detach();
+ }
+
+ this
+ }
+}
+
+impl Render for PerformanceStatusBarItem {
+ fn render(&mut self, cx: &mut gpui::ViewContext<Self>) -> impl gpui::IntoElement {
+ let text = match self.display_mode {
+ DisplayMode::StartupTime => cx
+ .time_to_first_window_draw()
+ .map_or("Pending".to_string(), |duration| {
+ format!("{}ms", duration.as_millis())
+ }),
+ DisplayMode::Fps => cx.fps().map_or("".to_string(), |fps| {
+ format!("{:3} FPS", fps.round() as u32)
+ }),
+ };
+
+ use gpui::ParentElement;
+ let display_mode = self.display_mode;
+ div()
+ .id("performance status")
+ .child(Label::new(text).size(LabelSize::Small))
+ .tooltip(move |cx| match display_mode {
+ DisplayMode::StartupTime => Tooltip::text("Time to first window draw", cx),
+ DisplayMode::Fps => cx
+ .new_view(|cx| {
+ let tooltip = Tooltip::new("Current FPS");
+ if let Some(time_to_first) = cx.time_to_first_window_draw() {
+ tooltip.meta(format!(
+ "Time to first window draw: {}ms",
+ time_to_first.as_millis()
+ ))
+ } else {
+ tooltip
+ }
+ })
+ .into(),
+ })
+ }
+}
+
+impl StatusItemView for PerformanceStatusBarItem {
+ fn set_active_pane_item(
+ &mut self,
+ _active_pane_item: Option<&dyn ItemHandle>,
+ _cx: &mut gpui::ViewContext<Self>,
+ ) {
+ // This is not currently used.
+ }
+}
+
+fn toggle_status_bar_item(
+ workspace: &mut Workspace,
+ enabled: bool,
+ start_time: Instant,
+ cx: &mut ViewContext<Workspace>,
+) {
+ if enabled {
+ workspace.status_bar().update(cx, |bar, cx| {
+ bar.add_right_item(
+ cx.new_view(|cx| PerformanceStatusBarItem::new(start_time, cx)),
+ cx,
+ )
+ });
+ } else {
+ workspace.status_bar().update(cx, |bar, cx| {
+ bar.remove_items_of_type::<PerformanceStatusBarItem>(cx);
+ });
+ }
+}
+
+/// Configuration of the display of performance details.
+#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct PerformanceSettings {
+ /// Display the time to first window draw and frame rate in the status bar.
+ ///
+ /// Default: false
+ pub show_in_status_bar: Option<bool>,
+}
+
+impl Settings for PerformanceSettings {
+ const KEY: Option<&'static str> = Some("performance");
+
+ type FileContent = Self;
+
+ fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
+ sources.json_merge()
+ }
+}
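Users opt in through a `performance` key in `settings.json` (for example `{"performance": {"show_in_status_bar": true}}`). Below is a minimal sketch of reading the setting from other code, assuming a `cx: &mut AppContext` and that `PerformanceSettings::register(cx)` has already run; it mirrors the `observe_global` pattern used by `init` above:

```rust
// Minimal sketch: query the setting once, then watch for changes, as `init` does.
let enabled = PerformanceSettings::get_global(cx)
    .show_in_status_bar
    .unwrap_or(false);
log::debug!("performance HUD enabled at startup: {enabled}");

cx.observe_global::<SettingsStore>(|cx| {
    let enabled = PerformanceSettings::get_global(cx)
        .show_in_status_bar
        .unwrap_or(false);
    log::debug!("performance HUD enabled: {enabled}");
})
.detach();
```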
@@ -51,6 +51,15 @@ pub struct Picker<D: PickerDelegate> {
is_modal: bool,
}
+#[derive(Debug, Default, Clone, Copy, PartialEq)]
+pub enum PickerEditorPosition {
+ #[default]
+    /// Render the editor at the start of the picker. Usually the top.
+    Start,
+    /// Render the editor at the end of the picker. Usually the bottom.
+ End,
+}
+
pub trait PickerDelegate: Sized + 'static {
type ListItem: IntoElement;
@@ -103,8 +112,16 @@ pub trait PickerDelegate: Sized + 'static {
None
}
+ fn editor_position(&self) -> PickerEditorPosition {
+ PickerEditorPosition::default()
+ }
+
fn render_editor(&self, editor: &View<Editor>, _cx: &mut ViewContext<Picker<Self>>) -> Div {
v_flex()
+ .when(
+ self.editor_position() == PickerEditorPosition::End,
+ |this| this.child(Divider::horizontal()),
+ )
.child(
h_flex()
.overflow_hidden()
@@ -113,7 +130,10 @@ pub trait PickerDelegate: Sized + 'static {
.px_3()
.child(editor.clone()),
)
- .child(Divider::horizontal())
+ .when(
+ self.editor_position() == PickerEditorPosition::Start,
+ |this| this.child(Divider::horizontal()),
+ )
}
fn render_match(
@@ -504,7 +524,7 @@ impl<D: PickerDelegate> Picker<D> {
picker
.border_color(cx.theme().colors().border_variant)
.border_b_1()
- .pb(px(-1.0))
+ .py(px(-1.0))
},
)
}
@@ -555,6 +575,8 @@ impl<D: PickerDelegate> ModalView for Picker<D> {}
impl<D: PickerDelegate> Render for Picker<D> {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let editor_position = self.delegate.editor_position();
+
v_flex()
.key_context("Picker")
.size_full()
@@ -574,9 +596,15 @@ impl<D: PickerDelegate> Render for Picker<D> {
.on_action(cx.listener(Self::secondary_confirm))
.on_action(cx.listener(Self::confirm_completion))
.on_action(cx.listener(Self::confirm_input))
- .child(match &self.head {
- Head::Editor(editor) => self.delegate.render_editor(&editor.clone(), cx),
- Head::Empty(empty_head) => div().child(empty_head.clone()),
+ .children(match &self.head {
+ Head::Editor(editor) => {
+ if editor_position == PickerEditorPosition::Start {
+ Some(self.delegate.render_editor(&editor.clone(), cx))
+ } else {
+ None
+ }
+ }
+ Head::Empty(empty_head) => Some(div().child(empty_head.clone())),
})
.when(self.delegate.match_count() > 0, |el| {
el.child(
@@ -602,5 +630,15 @@ impl<D: PickerDelegate> Render for Picker<D> {
)
})
.children(self.delegate.render_footer(cx))
+ .children(match &self.head {
+ Head::Editor(editor) => {
+ if editor_position == PickerEditorPosition::End {
+ Some(self.delegate.render_editor(&editor.clone(), cx))
+ } else {
+ None
+ }
+ }
+ Head::Empty(empty_head) => Some(div().child(empty_head.clone())),
+ })
}
}
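A delegate opts into the new layout with a one-method override. A minimal sketch follows (the `MyPickerDelegate` type is hypothetical, and the other required `PickerDelegate` methods are elided):

```rust
impl PickerDelegate for MyPickerDelegate {
    type ListItem = ListItem;

    // ...the remaining required PickerDelegate methods are elided in this sketch...

    fn editor_position(&self) -> PickerEditorPosition {
        // Render the query editor below the matches instead of above them.
        PickerEditorPosition::End
    }
}
```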
@@ -141,7 +141,7 @@ impl Render for QuickActionBar {
let assistant_button = QuickActionBarButton::new(
"toggle inline assistant",
- IconName::MagicWand,
+ IconName::ZedAssistant,
false,
Box::new(InlineAssist::default()),
"Inline Assist",
@@ -943,26 +943,31 @@ impl Item for TerminalView {
fn tab_content(&self, params: TabContentParams, cx: &WindowContext) -> AnyElement {
let terminal = self.terminal().read(cx);
let title = terminal.title(true);
+ let rerun_button = |task_id: task::TaskId| {
+ IconButton::new("rerun-icon", IconName::Rerun)
+ .icon_size(IconSize::Small)
+ .size(ButtonSize::Compact)
+ .icon_color(Color::Default)
+ .shape(ui::IconButtonShape::Square)
+ .tooltip(|cx| Tooltip::text("Rerun task", cx))
+ .on_click(move |_, cx| {
+ cx.dispatch_action(Box::new(tasks_ui::Rerun {
+ task_id: Some(task_id.clone()),
+ ..tasks_ui::Rerun::default()
+ }));
+ })
+ };
let (icon, icon_color, rerun_button) = match terminal.task() {
Some(terminal_task) => match &terminal_task.status {
- TaskStatus::Unknown => (IconName::ExclamationTriangle, Color::Warning, None),
TaskStatus::Running => (IconName::Play, Color::Disabled, None),
+ TaskStatus::Unknown => (
+ IconName::ExclamationTriangle,
+ Color::Warning,
+ Some(rerun_button(terminal_task.id.clone())),
+ ),
TaskStatus::Completed { success } => {
- let task_id = terminal_task.id.clone();
- let rerun_button = IconButton::new("rerun-icon", IconName::Rerun)
- .icon_size(IconSize::Small)
- .size(ButtonSize::Compact)
- .icon_color(Color::Default)
- .shape(ui::IconButtonShape::Square)
- .tooltip(|cx| Tooltip::text("Rerun task", cx))
- .on_click(move |_, cx| {
- cx.dispatch_action(Box::new(tasks_ui::Rerun {
- task_id: Some(task_id.clone()),
- ..Default::default()
- }));
- });
-
+ let rerun_button = rerun_button(terminal_task.id.clone());
if *success {
(IconName::Check, Color::Success, Some(rerun_button))
} else {
@@ -249,6 +249,7 @@ pub struct ThemeSettingsContent {
pub ui_font_fallbacks: Option<Vec<String>>,
/// The OpenType features to enable for text in the UI.
#[serde(default)]
+ #[schemars(default = "default_font_features")]
pub ui_font_features: Option<FontFeatures>,
/// The weight of the UI font in CSS units from 100 to 900.
#[serde(default)]
@@ -270,6 +271,7 @@ pub struct ThemeSettingsContent {
pub buffer_line_height: Option<BufferLineHeight>,
/// The OpenType features to enable for rendering in text buffers.
#[serde(default)]
+ #[schemars(default = "default_font_features")]
pub buffer_font_features: Option<FontFeatures>,
/// The name of the Zed theme to use.
#[serde(default)]
@@ -288,6 +290,10 @@ pub struct ThemeSettingsContent {
pub theme_overrides: Option<ThemeStyleContent>,
}
+fn default_font_features() -> Option<FontFeatures> {
+ Some(FontFeatures::default())
+}
+
impl ThemeSettingsContent {
/// Sets the theme for the given appearance to the theme with the specified name.
pub fn set_theme(&mut self, theme_name: String, appearance: Appearance) {
@@ -89,6 +89,7 @@ pub struct Button {
selected_icon: Option<IconName>,
selected_icon_color: Option<Color>,
key_binding: Option<KeyBinding>,
+ alpha: Option<f32>,
}
impl Button {
@@ -113,6 +114,7 @@ impl Button {
selected_icon: None,
selected_icon_color: None,
key_binding: None,
+ alpha: None,
}
}
@@ -181,6 +183,12 @@ impl Button {
self.key_binding = key_binding.into();
self
}
+
+    /// Sets the alpha applied to the button label's color.
+ pub fn alpha(mut self, alpha: f32) -> Self {
+ self.alpha = Some(alpha);
+ self
+ }
}
impl Selectable for Button {
@@ -409,6 +417,7 @@ impl RenderOnce for Button {
Label::new(label)
.color(label_color)
.size(self.label_size.unwrap_or_default())
+ .when_some(self.alpha, |this, alpha| this.alpha(alpha))
.line_height_style(LineHeightStyle::UiLabel),
)
.children(self.key_binding),
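A minimal sketch of the new builder method (the id and label here are hypothetical examples):

```rust
// Render a de-emphasized button by fading its label color.
let subtle = Button::new("rerun-task", "Rerun").alpha(0.5);
```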
@@ -107,6 +107,7 @@ impl IconSize {
pub enum IconName {
Ai,
AiAnthropic,
+ AiAnthropicHosted,
AiOpenAi,
AiGoogle,
AiOllama,
@@ -157,6 +158,7 @@ pub enum IconName {
Disconnected,
Download,
Ellipsis,
+ EllipsisVertical,
Envelope,
Escape,
ExclamationTriangle,
@@ -195,7 +197,6 @@ pub enum IconName {
LineHeight,
Link,
ListTree,
- MagicWand,
MagnifyingGlass,
MailOpen,
Maximize,
@@ -230,10 +231,13 @@ pub enum IconName {
Save,
Screen,
SearchSelection,
+ SearchCode,
SelectAll,
Server,
Settings,
Shift,
+ Slash,
+ SlashSquare,
Sliders,
SlidersAlt,
Snip,
@@ -272,6 +276,7 @@ impl IconName {
match self {
IconName::Ai => "icons/ai.svg",
IconName::AiAnthropic => "icons/ai_anthropic.svg",
+ IconName::AiAnthropicHosted => "icons/ai_anthropic_hosted.svg",
IconName::AiOpenAi => "icons/ai_open_ai.svg",
IconName::AiGoogle => "icons/ai_google.svg",
IconName::AiOllama => "icons/ai_ollama.svg",
@@ -321,6 +326,7 @@ impl IconName {
IconName::Disconnected => "icons/disconnected.svg",
IconName::Download => "icons/download.svg",
IconName::Ellipsis => "icons/ellipsis.svg",
+ IconName::EllipsisVertical => "icons/ellipsis_vertical.svg",
IconName::Envelope => "icons/feedback.svg",
IconName::Escape => "icons/escape.svg",
IconName::ExclamationTriangle => "icons/warning.svg",
@@ -359,7 +365,6 @@ impl IconName {
IconName::LineHeight => "icons/line_height.svg",
IconName::Link => "icons/link.svg",
IconName::ListTree => "icons/list_tree.svg",
- IconName::MagicWand => "icons/magic_wand.svg",
IconName::MagnifyingGlass => "icons/magnifying_glass.svg",
IconName::MailOpen => "icons/mail_open.svg",
IconName::Maximize => "icons/maximize.svg",
@@ -394,10 +399,13 @@ impl IconName {
IconName::Save => "icons/save.svg",
IconName::Screen => "icons/desktop.svg",
IconName::SearchSelection => "icons/search_selection.svg",
+ IconName::SearchCode => "icons/search_code.svg",
IconName::SelectAll => "icons/select_all.svg",
IconName::Server => "icons/server.svg",
IconName::Settings => "icons/file_icons/settings.svg",
IconName::Shift => "icons/shift.svg",
+ IconName::Slash => "icons/slash.svg",
+ IconName::SlashSquare => "icons/slash_square.svg",
IconName::Sliders => "icons/sliders.svg",
IconName::SlidersAlt => "icons/sliders-alt.svg",
IconName::Snip => "icons/snip.svg",
@@ -373,6 +373,7 @@ pub trait ItemHandle: 'static + Send {
fn dragged_tab_content(&self, params: TabContentParams, cx: &WindowContext) -> AnyElement;
fn project_path(&self, cx: &AppContext) -> Option<ProjectPath>;
fn project_entry_ids(&self, cx: &AppContext) -> SmallVec<[ProjectEntryId; 3]>;
+ fn project_paths(&self, cx: &AppContext) -> SmallVec<[ProjectPath; 3]>;
fn project_item_model_ids(&self, cx: &AppContext) -> SmallVec<[EntityId; 3]>;
fn for_each_project_item(
&self,
@@ -531,6 +532,16 @@ impl<T: Item> ItemHandle for View<T> {
result
}
+ fn project_paths(&self, cx: &AppContext) -> SmallVec<[ProjectPath; 3]> {
+ let mut result = SmallVec::new();
+ self.read(cx).for_each_project_item(cx, &mut |_, item| {
+ if let Some(id) = item.project_path(cx) {
+ result.push(id);
+ }
+ });
+ result
+ }
+
fn project_item_model_ids(&self, cx: &AppContext) -> SmallVec<[EntityId; 3]> {
let mut result = SmallVec::new();
self.read(cx).for_each_project_item(cx, &mut |id, _| {
@@ -920,7 +920,22 @@ impl Pane {
cx: &AppContext,
) -> Option<Box<dyn ItemHandle>> {
self.items.iter().find_map(|item| {
- if item.is_singleton(cx) && item.project_entry_ids(cx).as_slice() == [entry_id] {
+ if item.is_singleton(cx) && (item.project_entry_ids(cx).as_slice() == [entry_id]) {
+ Some(item.boxed_clone())
+ } else {
+ None
+ }
+ })
+ }
+
+ pub fn item_for_path(
+ &self,
+ project_path: ProjectPath,
+ cx: &AppContext,
+ ) -> Option<Box<dyn ItemHandle>> {
+ self.items.iter().find_map(move |item| {
+ if item.is_singleton(cx) && (item.project_path(cx).as_slice() == [project_path.clone()])
+ {
Some(item.boxed_clone())
} else {
None
@@ -153,6 +153,17 @@ impl StatusBar {
cx.notify();
}
+ pub fn remove_items_of_type<T>(&mut self, cx: &mut ViewContext<Self>)
+ where
+ T: 'static + StatusItemView,
+ {
+ self.left_items
+ .retain(|item| item.item_type() != TypeId::of::<T>());
+ self.right_items
+ .retain(|item| item.item_type() != TypeId::of::<T>());
+ cx.notify();
+ }
+
pub fn add_right_item<T>(&mut self, item: View<T>, cx: &mut ViewContext<Self>)
where
T: 'static + StatusItemView,
@@ -38,7 +38,7 @@ use gpui::{
ResizeEdge, Size, Stateful, Subscription, Task, Tiling, View, WeakView, WindowBounds,
WindowHandle, WindowId, WindowOptions,
};
-use item::{
+pub use item::{
FollowableItem, FollowableItemHandle, Item, ItemHandle, ItemSettings, PreviewTabsSettings,
ProjectItem, SerializableItem, SerializableItemHandle, WeakItemHandle,
};
@@ -115,6 +115,8 @@ lazy_static! {
#[derive(Clone, PartialEq)]
pub struct RemoveWorktreeFromProject(pub WorktreeId);
+actions!(assistant, [ShowConfiguration]);
+
actions!(
workspace,
[
@@ -2616,22 +2618,43 @@ impl Workspace {
open_project_item
}
- pub fn is_project_item_open<T>(
+ pub fn find_project_item<T>(
&self,
pane: &View<Pane>,
project_item: &Model<T::Item>,
cx: &AppContext,
- ) -> bool
+ ) -> Option<View<T>>
where
T: ProjectItem,
{
use project::Item as _;
+ let project_item = project_item.read(cx);
+ let entry_id = project_item.entry_id(cx);
+ let project_path = project_item.project_path(cx);
- project_item
- .read(cx)
- .entry_id(cx)
- .and_then(|entry_id| pane.read(cx).item_for_entry(entry_id, cx))
- .and_then(|item| item.downcast::<T>())
+ let mut item = None;
+ if let Some(entry_id) = entry_id {
+ item = pane.read(cx).item_for_entry(entry_id, cx);
+ }
+ if item.is_none() {
+ if let Some(project_path) = project_path {
+ item = pane.read(cx).item_for_path(project_path, cx);
+ }
+ }
+
+ item.and_then(|item| item.downcast::<T>())
+ }
+
+ pub fn is_project_item_open<T>(
+ &self,
+ pane: &View<Pane>,
+ project_item: &Model<T::Item>,
+ cx: &AppContext,
+ ) -> bool
+ where
+ T: ProjectItem,
+ {
+ self.find_project_item::<T>(pane, project_item, cx)
.is_some()
}
@@ -2646,19 +2669,12 @@ impl Workspace {
where
T: ProjectItem,
{
- use project::Item as _;
-
- let entry_id = project_item.read(cx).entry_id(cx);
- if let Some(item) = entry_id
- .and_then(|entry_id| pane.read(cx).item_for_entry(entry_id, cx))
- .and_then(|item| item.downcast())
- {
+ if let Some(item) = self.find_project_item(&pane, &project_item, cx) {
self.activate_item(&item, activate_pane, focus_item, cx);
return item;
}
let item = cx.new_view(|cx| T::for_project_item(self.project().clone(), project_item, cx));
-
let item_id = item.item_id();
let mut destination_index = None;
pane.update(cx, |pane, cx| {
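A minimal sketch of calling the extracted helper from inside another `Workspace` method (the `pane`, `buffer`, and `Editor` bindings are hypothetical; `cx` is assumed to be a `&mut ViewContext<Workspace>`). The lookup matches by project entry id first and falls back to the project path, which covers items whose entries no longer exist on disk:

```rust
// Hypothetical call site: reuse an existing editor for `buffer` if one is open in `pane`.
if let Some(existing) = self.find_project_item::<Editor>(&pane, &buffer, cx) {
    self.activate_item(&existing, true, true, cx);
}
```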
@@ -72,6 +72,7 @@ outline.workspace = true
outline_panel.workspace = true
parking_lot.workspace = true
paths.workspace = true
+performance.workspace = true
profiling.workspace = true
project.workspace = true
project_panel.workspace = true
@@ -266,6 +266,7 @@ fn init_ui(
welcome::init(cx);
settings_ui::init(cx);
extensions_ui::init(cx);
+ performance::init(cx);
cx.observe_global::<SettingsStore>({
let languages = app_state.languages.clone();
@@ -315,6 +316,7 @@ fn init_ui(
}
fn main() {
+ let start_time = std::time::Instant::now();
menu::init();
zed_actions::init();
@@ -326,7 +328,9 @@ fn main() {
init_logger();
log::info!("========== starting zed ==========");
- let app = App::new().with_assets(Assets);
+ let app = App::new()
+ .with_assets(Assets)
+ .measure_time_to_first_window_draw(start_time);
let (installation_id, existing_installation_id_found) = app
.background_executor()
@@ -1018,8 +1018,6 @@ fn open_settings_file(
#[cfg(test)]
mod tests {
- use crate::stdout_is_a_pty;
-
use super::*;
use anyhow::anyhow;
use assets::Assets;
@@ -3487,12 +3485,8 @@ mod tests {
app_state.fs.clone(),
cx,
);
- let prompt_builder = assistant::init(
- app_state.fs.clone(),
- app_state.client.clone(),
- stdout_is_a_pty(),
- cx,
- );
+ let prompt_builder =
+ assistant::init(app_state.fs.clone(), app_state.client.clone(), false, cx);
repl::init(
app_state.fs.clone(),
app_state.client.telemetry().clone(),
@@ -22,7 +22,7 @@
# Using Zed
- [Multibuffers](./multibuffers.md)
-- [Language model integration](./language-model-integration.md)
+- [Assistant](./assistant.md)
- [Code Completions](./completions.md)
- [Channels](./channels.md)
- [Collaboration](./collaboration.md)
@@ -1,11 +1,17 @@
-# Language model integration
+# Assistant
## Assistant Panel
-The assistant panel provides you with a way to interact with large language models. The assistant is good for various tasks, such as generating code, asking questions about existing code, and even writing plaintext, such as emails and documentation. To open the assistant panel, toggle the right dock by using the `workspace: toggle right dock` action in the command palette or by using the `cmd-r` (Mac) or `ctrl-alt-b` (Linux) shortcut.
+The assistant panel provides you with a way to interact with large language models. The assistant is useful for various tasks, such as generating code, asking questions about existing code, and even writing plaintext, such as emails and documentation. To open the assistant panel, toggle the right dock by using the `workspace: toggle right dock` action in the command palette or by using the `cmd-r` (Mac) or `ctrl-alt-b` (Linux) shortcut.
> **Note**: A custom [key binding](./key-bindings.md) can be set to toggle the right dock.
+Once you have configured a provider, you can interact with the provider's language models in a context editor.
+
+To create a new context editor, use the menu in the top right of the assistant panel and select the `New Context` option.
+
+In the context editor, select a model from one of the configured providers, type a message in the `You` block, and submit with `cmd-enter` (or `ctrl-enter` on Linux).
+
## Setup
- [OpenAI API Setup Instructions](#openai)
@@ -15,7 +21,7 @@ The assistant panel provides you with a way to interact with large language mode
- [Google Gemini API Setup Instructions](#google-gemini)
- [GitHub Copilot Chat](#github-copilot)
-### Having a conversation
+### Having a Conversation
The assistant editor in Zed functions similarly to any other editor. You can use custom key bindings and work with multiple cursors, allowing for seamless transitions between coding and engaging in discussions with the language models. However, the assistant editor differs with the inclusion of message blocks. These blocks serve as containers for text that correspond to different roles within the conversation. These roles include:
@@ -45,35 +51,66 @@ If you want to start a new conversation at any time, you can hit `cmd-n` or use
Simple back-and-forth conversations work well with the assistant. However, there may come a time when you want to modify the previous text in the conversation and steer it in a different direction.
-### Editing a conversation
+### Editing a Conversation
The assistant gives you the flexibility to have control over the conversation. You can freely edit any previous text, including the responses from the assistant. If you want to remove a message block entirely, simply place your cursor at the beginning of the block and use the `delete` key. A typical workflow might involve making edits and adjustments throughout the conversation to refine your inquiry or provide additional context. Here's an example:
1. Write text in a `You` block.
-2. Submit the message with `cmd-enter`
-3. Receive an `Assistant` response that doesn't meet your expectations
-4. Cancel the response with `escape`
-5. Erase the content of the `Assistant` message block and remove the block entirely
-6. Add additional context to your original message
-7. Submit the message with `cmd-enter`
+2. Submit the message with `cmd-enter`.
+3. Receive an `Assistant` response that doesn't meet your expectations.
+4. Cancel the response with `escape`.
+5. Erase the content of the `Assistant` message block and remove the block entirely.
+6. Add additional context to your original message.
+7. Submit the message with `cmd-enter`.
-Being able to edit previous messages gives you control over how tokens are used. You don't need to start up a new context to correct a mistake or to add additional context and you don't have to waste tokens by submitting follow-up corrections.
+Being able to edit previous messages gives you control over how tokens are used. You don't need to start up a new context to correct a mistake or to add additional context, and you don't have to waste tokens by submitting follow-up corrections.
Some additional points to keep in mind:
- You are free to change the model type at any point in the conversation.
- You can cycle the role of a message block by clicking on the role, which is useful when you receive a response in an `Assistant` block that you want to edit and send back up as a `You` block.
-### Saving and loading conversations
+### Saving and Loading Conversations
After you submit your first message, a name for your conversation is generated by the language model, and the conversation is automatically saved to your file system in `~/.config/zed/conversations`. You can access and load previous messages by clicking on the hamburger button in the top-left corner of the assistant panel.

-## Inline generation
+### Adding Prompts
+
+You can customize the default prompts used in new context editors by opening the `Prompt Library`.
+
+Open the `Prompt Library` either by using the menu in the top right of the assistant panel and choosing the `Prompt Library` option, or by running the `assistant: deploy prompt library` command when the assistant panel is focused.
+
+### Viewing Past Contexts
+
+You can view all previous contexts by opening the `History` tab in the assistant panel.
+
+Open the `History` tab by using the menu in the top right of the assistant panel and choosing `History`.
+
+### Slash Commands
+
+Slash commands enhance the assistant's capabilities. Begin by typing a `/` at the beginning of the line to see a list of available commands:
+
+- `/default`: Inserts the default prompt into the context
+- `/diagnostics`: Injects errors reported by the project's language server into the context
+- `/fetch`: Fetches the content of a webpage and inserts it into the context
+- `/file`: Inserts a single file or a directory of files into the context
+- `/now`: Inserts the current date and time into the context
+- `/prompt`: Adds a custom-configured prompt to the context (see Prompt Library)
+- `/search`: Performs semantic search for content in your project based on natural language
+- `/symbols`: Inserts the current tab's active symbols into the context
+- `/tab`: Inserts the content of the active tab or all open tabs into the context
+- `/terminal`: Inserts a specified number of lines of output from the terminal
+
+## Inline Assistant
+
+You can use `ctrl-enter` to open the inline assistant in a normal editor, within the assistant panel, and even within the terminal panel.
+
+The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response. You can also perform multiple generation requests in parallel by pressing `ctrl-enter` with multiple cursors, or by pressing `ctrl-enter` with a selection that spans multiple excerpts in a multibuffer.
+
+The inline assistant pulls its context from the assistant panel, allowing you to provide additional instructions or rules for code transformations.
-You can generate and transform text in any editor by selecting text and pressing `ctrl-enter`.
-You can also perform multiple generation requests in parallel by pressing `ctrl-enter` with multiple cursors, or by pressing `ctrl-enter` with a selection that spans multiple excerpts in a multibuffer.
To create a custom keybinding that prefills a prompt, you can add the following format in your keymap:
```json
@@ -90,7 +127,7 @@ To create a custom keybinding that prefills a prompt, you can add the following
]
```
-## Advanced: Overriding prompt templates
+## Advanced: Overriding Prompt Templates
Zed allows you to override the default prompts used for various assistant features by placing custom Handlebars (.hbs) templates in your `~/.config/zed/prompts/templates` directory. The following templates can be overridden:
@@ -165,7 +202,7 @@ Zed allows you to override the default prompts used for various assistant featur
You can customize these templates to better suit your needs while maintaining the core structure and variables used by Zed. Zed will automatically reload your prompt overrides when they change on disk. Consult Zed's assets/prompts directory for current versions you can play with.
-Be sure you want to override these, as you'll miss out on iteration on our built in features. This should be primarily used when developing Zed.
+Be sure you want to override these, as you'll miss out on iteration on our built-in features. This should be primarily used when developing Zed.
## Setup Instructions
@@ -206,7 +243,7 @@ The custom URL here is `http://localhost:11434/v1`.
### Ollama
-Download and install ollama from [ollama.com/download](https://ollama.com/download) (Linux or MacOS) and ensure it's running with `ollama --version`.
+Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`.
You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
@@ -223,7 +260,7 @@ You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI
```
3. In the assistant panel, select one of the Ollama models using the model dropdown.
-4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings:
+4. (Optional) If you want to change the default URL that is used to access the Ollama server, you can do so by adding the following settings:
```json
{
@@ -249,6 +286,6 @@ You can use Gemini 1.5 Pro/Flash with the Zed assistant by choosing it via the m
You can obtain an API key [here](https://aistudio.google.com/app/apikey).
-### GitHub Copilot
+### GitHub Copilot Chat
You can use GitHub Copilot chat with the Zed assistant by choosing it via the model dropdown in the assistant panel.
@@ -88,4 +88,4 @@ You can also add this as a language-specific setting in your `settings.json` to
## See also
-You may also use the Assistant Panel or the Inline Assistant to interact with language models, see [Language Model Integration](language-model-integration.md) documentation for more information.
+You may also use the Assistant Panel or the Inline Assistant to interact with language models; see the [assistant](assistant.md) documentation for more information.
@@ -32,8 +32,9 @@ my-extension/
src/
lib.rs
languages/
- config.toml
- highlights.scm
+ my-language/
+ config.toml
+ highlights.scm
themes/
my-theme.json
```
@@ -0,0 +1,16 @@
+[package]
+name = "perplexity"
+version = "0.1.0"
+edition = "2021"
+license = "Apache-2.0"
+
+[lib]
+path = "src/perplexity.rs"
+crate-type = ["cdylib"]
+
+[lints]
+workspace = true
+
+[dependencies]
+serde = "1"
+zed_extension_api = { path = "../../crates/extension_api" }
@@ -0,0 +1 @@
+../../LICENSE-APACHE
@@ -0,0 +1,12 @@
+id = "perplexity"
+name = "Perplexity"
+version = "0.1.0"
+description = "Ask questions to Perplexity AI directly from Zed"
+authors = ["Zed Industries <support@zed.dev>"]
+repository = "https://github.com/zed-industries/zed-perplexity"
+schema_version = 1
+
+[slash_commands.perplexity]
+description = "Ask a question to Perplexity AI"
+requires_argument = true
+tooltip_text = "Ask Perplexity"
@@ -0,0 +1,158 @@
+use zed::{
+ http_client::HttpMethod,
+ http_client::HttpRequest,
+ serde_json::{self, json},
+};
+use zed_extension_api::{self as zed, http_client::RedirectPolicy, Result};
+
+struct Perplexity;
+
+impl zed::Extension for Perplexity {
+ fn new() -> Self {
+ Self
+ }
+
+ fn run_slash_command(
+ &self,
+ command: zed::SlashCommand,
+ argument: Vec<String>,
+ worktree: Option<&zed::Worktree>,
+ ) -> zed::Result<zed::SlashCommandOutput> {
+ // Check if the command is 'perplexity'
+ if command.name != "perplexity" {
+ return Err("Invalid command. Expected 'perplexity'.".into());
+ }
+
+ let worktree = worktree.ok_or("Worktree is required")?;
+ // Join arguments with space as the query
+ let query = argument.join(" ");
+ if query.is_empty() {
+ return Ok(zed::SlashCommandOutput {
+ text: "Error: Query not provided. Please enter a question or topic.".to_string(),
+ sections: vec![],
+ });
+ }
+
+ // Get the API key from the environment
+ let env_vars = worktree.shell_env();
+ let api_key = env_vars
+ .iter()
+ .find(|(key, _)| key == "PERPLEXITY_API_KEY")
+ .map(|(_, value)| value.clone())
+ .ok_or("PERPLEXITY_API_KEY not found in environment")?;
+
+ // Prepare the request
+ let request = HttpRequest {
+ method: HttpMethod::Post,
+ url: "https://api.perplexity.ai/chat/completions".to_string(),
+ headers: vec![
+ ("Authorization".to_string(), format!("Bearer {}", api_key)),
+ ("Content-Type".to_string(), "application/json".to_string()),
+ ],
+ body: Some(
+ serde_json::to_vec(&json!({
+ "model": "llama-3.1-sonar-small-128k-online",
+ "messages": [{"role": "user", "content": query}],
+ "stream": true,
+ }))
+ .unwrap(),
+ ),
+ redirect_policy: RedirectPolicy::FollowAll,
+ };
+
+ // Make the HTTP request
+ match zed::http_client::fetch_stream(&request) {
+ Ok(stream) => {
+ let mut full_content = String::new();
+ let mut buffer = String::new();
+ while let Ok(Some(chunk)) = stream.next_chunk() {
+ buffer.push_str(&String::from_utf8_lossy(&chunk));
+                    while let Some(newline_ix) = buffer.find('\n') {
+                        let line: String = buffer.drain(..=newline_ix).collect(); // keep any partial trailing line buffered for the next chunk
+                        if let Some(json) = line.trim_end().strip_prefix("data: ") {
+                            if let Ok(event) = serde_json::from_str::<StreamEvent>(json) {
+                                if let Some(choice) = event.choices.first() {
+                                    full_content.push_str(&choice.delta.content);
+                                }
+                            }
+                        }
+                    }
+ }
+ Ok(zed::SlashCommandOutput {
+ text: full_content,
+ sections: vec![],
+ })
+ }
+ Err(e) => Ok(zed::SlashCommandOutput {
+                text: format!("API request failed. Error: {}", e),
+ sections: vec![],
+ }),
+ }
+ }
+
+ fn complete_slash_command_argument(
+ &self,
+ _command: zed::SlashCommand,
+ query: Vec<String>,
+ ) -> zed::Result<Vec<zed::SlashCommandArgumentCompletion>> {
+ let suggestions = vec!["How do I develop a Zed extension?"];
+ let query = query.join(" ").to_lowercase();
+
+ Ok(suggestions
+ .into_iter()
+ .filter(|suggestion| suggestion.to_lowercase().contains(&query))
+ .map(|suggestion| zed::SlashCommandArgumentCompletion {
+ label: suggestion.to_string(),
+ new_text: suggestion.to_string(),
+ run_command: true,
+ })
+ .collect())
+ }
+
+ fn language_server_command(
+ &mut self,
+ _language_server_id: &zed_extension_api::LanguageServerId,
+ _worktree: &zed_extension_api::Worktree,
+ ) -> Result<zed_extension_api::Command> {
+ Err("Not implemented".into())
+ }
+}
+
+#[derive(serde::Deserialize)]
+struct StreamEvent {
+ id: String,
+ model: String,
+ created: u64,
+ usage: Usage,
+ object: String,
+ choices: Vec<Choice>,
+}
+
+#[derive(serde::Deserialize)]
+struct Usage {
+ prompt_tokens: u32,
+ completion_tokens: u32,
+ total_tokens: u32,
+}
+
+#[derive(serde::Deserialize)]
+struct Choice {
+ index: u32,
+ finish_reason: Option<String>,
+ message: Message,
+ delta: Delta,
+}
+
+#[derive(serde::Deserialize)]
+struct Message {
+ role: String,
+ content: String,
+}
+
+#[derive(serde::Deserialize)]
+struct Delta {
+ role: String,
+ content: String,
+}
+
+zed::register_extension!(Perplexity);
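As a sanity check on the structs above, here is a hedged sketch of deserializing one streamed `data:` line with the `serde_json` re-export already imported in this file; the JSON payload is an illustrative example, not captured from the live API:

```rust
// Illustrative chunk shaped to match the StreamEvent/Choice/Delta structs above.
let chunk = r#"{
    "id": "example-id",
    "model": "llama-3.1-sonar-small-128k-online",
    "created": 1700000000,
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "finish_reason": null,
        "message": {"role": "assistant", "content": "Hello"},
        "delta": {"role": "assistant", "content": "Hello"}
    }]
}"#;
let event: StreamEvent = serde_json::from_str(chunk).expect("example chunk should parse");
assert_eq!(event.choices[0].delta.content, "Hello");
```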
@@ -16,8 +16,8 @@
# --worktree option. It also provides informative output and error handling.
if [ "$1" = "link" ]; then
- # Remove existing link
- rm -f ~/.config/zed/prompt_overrides
+ # Remove existing link (or directory)
+ rm -rf ~/.config/zed/prompt_overrides
if [ "$2" = "--worktree" ]; then
# Check if 'prompts' branch exists, create if not
if ! git show-ref --quiet refs/heads/prompts; then
@@ -30,17 +30,21 @@ if [ "$1" = "link" ]; then
# Create worktree if it doesn't exist
git worktree add ../zed_prompts prompts || git worktree add ../zed_prompts -b prompts
fi
- ln -sf "$(pwd)/../zed_prompts/assets/prompts" ~/.config/zed/prompt_overrides
+ ln -sf "$(realpath "$(pwd)/../zed_prompts/assets/prompts")" ~/.config/zed/prompt_overrides
echo "Linked $(realpath "$(pwd)/../zed_prompts/assets/prompts") to ~/.config/zed/prompt_overrides"
- echo -e "\033[0;31mDon't forget you have it linked, or your prompts will go stale\033[0m"
+ echo -e "\033[0;33mDon't forget you have it linked, or your prompts will go stale\033[0m"
else
ln -sf "$(pwd)/assets/prompts" ~/.config/zed/prompt_overrides
echo "Linked $(pwd)/assets/prompts to ~/.config/zed/prompt_overrides"
fi
elif [ "$1" = "unlink" ]; then
- # Remove symbolic link
- rm ~/.config/zed/prompt_overrides
- echo "Unlinked ~/.config/zed/prompt_overrides"
+ if [ -e ~/.config/zed/prompt_overrides ]; then
+ # Remove symbolic link
+ rm -rf ~/.config/zed/prompt_overrides
+ echo "Unlinked ~/.config/zed/prompt_overrides"
+ else
+ echo -e "\033[33mWarning: No file exists at ~/.config/zed/prompt_overrides\033[0m"
+ fi
else
echo "This script helps you manage prompt overrides for Zed."
echo "You can link this directory to have Zed use the contents of your current repo templates as your active prompts,"