Detailed changes
@@ -16,7 +16,7 @@ jobs:
fi
echo "::set-output name=URL::$URL"
- name: Get content
- uses: 2428392/gh-truncate-string-action@67b1b814955634208b103cff064be3cb1c7a19be # v1.3.0
+ uses: 2428392/gh-truncate-string-action@e6b5885fb83c81ca9a700a91b079baec2133be3e # v1.4.0
id: get-content
with:
stringToTruncate: |
@@ -3,5 +3,10 @@
"label": "clippy",
"command": "./script/clippy",
"args": []
+ },
+ {
+ "label": "cargo run --profile release-fast",
+ "command": "cargo",
+ "args": ["run", "--profile", "release-fast"]
}
]
@@ -358,6 +358,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
+ "context_servers",
"ctor",
"db",
"editor",
@@ -366,6 +367,7 @@ dependencies = [
"fs",
"futures 0.3.30",
"fuzzy",
+ "globset",
"gpui",
"handlebars",
"heed",
@@ -397,6 +399,7 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
+ "similar",
"smallvec",
"smol",
"telemetry_events",
@@ -791,9 +794,9 @@ dependencies = [
[[package]]
name = "async-stripe"
-version = "0.37.3"
+version = "0.38.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2f14b5943a52cf051bbbbb68538e93a69d1e291934174121e769f4b181113f5"
+checksum = "97ddaa6999d246ba2c6c84d830a1ba0cd16c9234d58701988b3869f0e5bd732d"
dependencies = [
"chrono",
"futures-util",
@@ -983,9 +986,9 @@ dependencies = [
[[package]]
name = "aws-config"
-version = "1.5.4"
+version = "1.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "caf6cfe2881cb1fcbba9ae946fb9a6480d3b7a714ca84c74925014a89ef3387a"
+checksum = "4e95816a168520d72c0e7680c405a5a8c1fb6a035b4bc4b9d7b0de8e1a941697"
dependencies = [
"aws-credential-types",
"aws-runtime",
@@ -1003,7 +1006,6 @@ dependencies = [
"fastrand 2.1.0",
"hex",
"http 0.2.12",
- "hyper",
"ring",
"time",
"tokio",
@@ -1026,9 +1028,9 @@ dependencies = [
[[package]]
name = "aws-runtime"
-version = "1.3.1"
+version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87c5f920ffd1e0526ec9e70e50bf444db50b204395a0fa7016bbf9e31ea1698f"
+checksum = "f42c2d4218de4dcd890a109461e2f799a1a2ba3bcd2cde9af88360f5df9266c6"
dependencies = [
"aws-credential-types",
"aws-sigv4",
@@ -1042,6 +1044,7 @@ dependencies = [
"fastrand 2.1.0",
"http 0.2.12",
"http-body 0.4.6",
+ "once_cell",
"percent-encoding",
"pin-project-lite",
"tracing",
@@ -1050,9 +1053,9 @@ dependencies = [
[[package]]
name = "aws-sdk-s3"
-version = "1.42.0"
+version = "1.43.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "558bbcec8db82a1a8af1610afcb3b10d00652d25ad366a0558eecdff2400a1d1"
+checksum = "9ccda7e730ace3cb8bbd4071bc650c6d294364891f9564bd4e43adfc8dea3177"
dependencies = [
"ahash 0.8.11",
"aws-credential-types",
@@ -1085,9 +1088,9 @@ dependencies = [
[[package]]
name = "aws-sdk-sso"
-version = "1.36.0"
+version = "1.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6acca681c53374bf1d9af0e317a41d12a44902ca0f2d1e10e5cb5bb98ed74f35"
+checksum = "1074e818fbe4f9169242d78448b15be8916a79daa38ea1231f2e2e10d993fcd2"
dependencies = [
"aws-credential-types",
"aws-runtime",
@@ -1107,9 +1110,9 @@ dependencies = [
[[package]]
name = "aws-sdk-ssooidc"
-version = "1.37.0"
+version = "1.38.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b79c6bdfe612503a526059c05c9ccccbf6bd9530b003673cb863e547fd7c0c9a"
+checksum = "29755c51e33fa3f678598f64324a169cf4b7d3c4865d2709d4308f53366a92a4"
dependencies = [
"aws-credential-types",
"aws-runtime",
@@ -1129,9 +1132,9 @@ dependencies = [
[[package]]
name = "aws-sdk-sts"
-version = "1.36.0"
+version = "1.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32e6ecdb2bd756f3b2383e6f0588dc10a4e65f5d551e70a56e0bfe0c884673ce"
+checksum = "6e52dc3fd7dfa6c01a69cf3903e00aa467261639138a05b06cd92314d2c8fb07"
dependencies = [
"aws-credential-types",
"aws-runtime",
@@ -1192,9 +1195,9 @@ dependencies = [
[[package]]
name = "aws-smithy-checksums"
-version = "0.60.11"
+version = "0.60.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "48c4134cf3adaeacff34d588dbe814200357b0c466d730cf1c0d8054384a2de4"
+checksum = "598b1689d001c4d4dc3cb386adb07d37786783aee3ac4b324bcadac116bf3d23"
dependencies = [
"aws-smithy-http",
"aws-smithy-types",
@@ -1291,9 +1294,9 @@ dependencies = [
[[package]]
name = "aws-smithy-runtime-api"
-version = "1.7.1"
+version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "30819352ed0a04ecf6a2f3477e344d2d1ba33d43e0f09ad9047c12e0d923616f"
+checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96"
dependencies = [
"aws-smithy-async",
"aws-smithy-types",
@@ -2206,9 +2209,9 @@ dependencies = [
[[package]]
name = "clap"
-version = "4.5.13"
+version = "4.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0fbb260a053428790f3de475e304ff84cdbc4face759ea7a3e64c1edd938a7fc"
+checksum = "11d8838454fda655dafd3accb2b6e2bea645b9e4078abe84a22ceb947235c5cc"
dependencies = [
"clap_builder",
"clap_derive",
@@ -2216,9 +2219,9 @@ dependencies = [
[[package]]
name = "clap_builder"
-version = "4.5.13"
+version = "4.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64b17d7ea74e9f833c7dbf2cbe4fb12ff26783eda4782a8975b72f895c9b4d99"
+checksum = "216aec2b177652e3846684cbfe25c9964d18ec45234f0f5da5157b207ed1aab6"
dependencies = [
"anstream",
"anstyle",
@@ -2668,6 +2671,27 @@ dependencies = [
"tiny-keccak",
]
+[[package]]
+name = "context_servers"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "collections",
+ "futures 0.3.30",
+ "gpui",
+ "log",
+ "parking_lot",
+ "postage",
+ "schemars",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+ "url",
+ "util",
+ "workspace",
+]
+
[[package]]
name = "convert_case"
version = "0.4.0"
@@ -2736,9 +2760,9 @@ dependencies = [
[[package]]
name = "core-foundation-sys"
-version = "0.8.6"
+version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
+checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b"
[[package]]
name = "core-graphics"
@@ -3141,12 +3165,12 @@ dependencies = [
[[package]]
name = "ctrlc"
-version = "3.4.4"
+version = "3.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "672465ae37dc1bc6380a6547a8883d5dd397b0f1faaad4f265726cc7042a5345"
+checksum = "90eeab0aa92f3f9b4e87f258c72b139c207d251f9cbc1080a0086b86a8870dd3"
dependencies = [
- "nix 0.28.0",
- "windows-sys 0.52.0",
+ "nix 0.29.0",
+ "windows-sys 0.59.0",
]
[[package]]
@@ -5033,9 +5057,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "heed"
-version = "0.20.3"
+version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2bc30da4a93ff8cb98e535d595d6de42731d4719d707bc1c86f579158751a24e"
+checksum = "620033c8c8edfd2f53e6f99a30565eb56a33b42c468e3ad80e21d85fb93bafb0"
dependencies = [
"bitflags 2.6.0",
"byteorder",
@@ -6220,18 +6244,18 @@ dependencies = [
[[package]]
name = "linkme"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccb76662d78edc9f9bf56360d6919bdacc8b7761227727e5082f128eeb90bbf5"
+checksum = "3c943daedff228392b791b33bba32e75737756e80a613e32e246c6ce9cbab20a"
dependencies = [
"linkme-impl",
]
[[package]]
name = "linkme-impl"
-version = "0.3.27"
+version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f8dccda732e04fa3baf2e17cf835bfe2601c7c2edafd64417c627dabae3a8cda"
+checksum = "cb26336e6dc7cc76e7927d2c9e7e3bb376d7af65a6f56a0b16c47d18a9b1abc5"
dependencies = [
"proc-macro2",
"quote",
@@ -6290,9 +6314,9 @@ dependencies = [
[[package]]
name = "lmdb-master-sys"
-version = "0.2.2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57640c190703d5ccf4a86aff4aeb749b2d287a8cb1723c76b51f39d77ab53b24"
+checksum = "1de7e761853c15ca72821d9f928d7bb123ef4c05377c4e7ab69fa1c742f91d24"
dependencies = [
"cc",
"doxygen-rs",
@@ -9533,18 +9557,18 @@ dependencies = [
[[package]]
name = "serde"
-version = "1.0.204"
+version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12"
+checksum = "5665e14a49a4ea1b91029ba7d3bca9f299e1f7cfa194388ccc20f14743e784f2"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
-version = "1.0.204"
+version = "1.0.207"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222"
+checksum = "6aea2634c86b0e8ef2cfdc0c340baede54ec27b1e46febd7f80dffb2aa44a00e"
dependencies = [
"proc-macro2",
"quote",
@@ -9584,9 +9608,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.122"
+version = "1.0.125"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da"
+checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed"
dependencies = [
"indexmap 2.3.0",
"itoa",
@@ -9920,6 +9944,13 @@ dependencies = [
"autocfg",
]
+[[package]]
+name = "slash_commands_example"
+version = "0.1.0"
+dependencies = [
+ "zed_extension_api 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
[[package]]
name = "slice-group-by"
version = "0.3.1"
@@ -13783,7 +13814,7 @@ dependencies = [
[[package]]
name = "zed"
-version = "0.149.0"
+version = "0.150.0"
dependencies = [
"activity_indicator",
"anyhow",
@@ -13977,12 +14008,23 @@ dependencies = [
"wit-bindgen",
]
+[[package]]
+name = "zed_extension_api"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "594fd10dd0f2f853eb243e2425e7c95938cef49adb81d9602921d002c5e6d9d9"
+dependencies = [
+ "serde",
+ "serde_json",
+ "wit-bindgen",
+]
+
[[package]]
name = "zed_gleam"
-version = "0.1.3"
+version = "0.2.0"
dependencies = [
"html_to_markdown 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "zed_extension_api 0.1.0",
+ "zed_extension_api 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -14108,9 +14150,9 @@ dependencies = [
[[package]]
name = "zed_zig"
-version = "0.1.5"
+version = "0.2.0"
dependencies = [
- "zed_extension_api 0.0.6",
+ "zed_extension_api 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -19,6 +19,7 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
+ "crates/context_servers",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@@ -148,6 +149,7 @@ members = [
"extensions/purescript",
"extensions/ruff",
"extensions/ruby",
+ "extensions/slash-commands-example",
"extensions/snippets",
"extensions/svelte",
"extensions/terraform",
@@ -189,6 +191,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
+context_servers = { path = "crates/context_servers" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@@ -461,7 +464,7 @@ which = "6.0.0"
wit-component = "0.201"
[workspace.dependencies.async-stripe]
-version = "0.37"
+version = "0.38"
default-features = false
features = [
"runtime-tokio-hyper-rustls",
@@ -89,9 +89,9 @@
"alt-right": "editor::MoveToNextWordEnd",
"alt-f": "editor::MoveToNextWordEnd",
"cmd-left": "editor::MoveToBeginningOfLine",
- "ctrl-a": ["editor::MoveToBeginningOfLine", { "stop_at_soft_wraps": false }],
+ "ctrl-a": "editor::MoveToBeginningOfLine",
"cmd-right": "editor::MoveToEndOfLine",
- "ctrl-e": ["editor::MoveToEndOfLine", { "stop_at_soft_wraps": false }],
+ "ctrl-e": "editor::MoveToEndOfLine",
"cmd-up": "editor::MoveToBeginning",
"cmd-down": "editor::MoveToEnd",
"shift-up": "editor::SelectUp",
@@ -1,49 +1,423 @@
-Here's a text file that I'm going to ask you to make an edit to.
+You are an expert developer assistant working in an AI-enabled text editor.
+Your task is to rewrite a specific section of the provided document based on a user-provided prompt.
-{{#if language_name}}
-The file is in {{language_name}}.
-{{/if}}
+<guidelines>
+1. Scope: Modify only content within <rewrite_this> tags. Do not alter anything outside these boundaries.
+2. Precision: Make changes strictly necessary to fulfill the given prompt. Preserve all other content as-is.
+3. Seamless integration: Ensure rewritten sections flow naturally with surrounding text and maintain document structure.
+4. Tag exclusion: Never include <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags in the output.
+5. Indentation: Maintain the original indentation level of the file in rewritten sections.
+6. Completeness: Rewrite the entire tagged section, even if only partial changes are needed. Avoid omissions or elisions.
+7. Insertions: Replace <insert_here></insert_here> tags with appropriate content as specified by the prompt.
+8. Code integrity: Respect existing code structure and functionality when making changes.
+9. Consistency: Maintain a uniform style and tone throughout the rewritten text.
+</guidelines>
-You need to rewrite a portion of it.
+<examples>
+<example>
+<input>
+<document>
+use std::cell::Cell;
+use std::collections::HashMap;
+use std::cmp;
-The section you'll need to edit is marked with <rewrite_this></rewrite_this> tags.
+<rewrite_this>
+<insert_here></insert_here>
+</rewrite_this>
+pub struct LruCache<K, V> {
+ /// The maximum number of items the cache can hold.
+ capacity: usize,
+ /// The map storing the cached items.
+ items: HashMap<K, V>,
+}
+// The rest of the implementation...
+</document>
+<prompt>
+doc this
+</prompt>
+</input>
+
+<incorrect_output failure="Over-generation. The text starting with `pub struct AabbTree<T> {` is *after* the rewrite_this tag">
+/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
+///
+/// This structure is used for efficient spatial queries and collision detection.
+/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
+///
+/// # Type Parameters
+///
+/// * `T`: The type of data associated with each node in the tree.
+pub struct AabbTree<T> {
+ root: Option<usize>,
+</incorrect_output>
+<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
+/// Represents an Axis-Aligned Bounding Box (AABB) tree data structure.
+///
+/// This structure is used for efficient spatial queries and collision detection.
+/// It organizes objects in a hierarchical tree structure based on their bounding boxes.
+///
+/// # Type Parameters
+///
+/// * `T`: The type of data associated with each node in the tree.
+</corrected_output>
+</example>
+
+<example>
+<input>
<document>
-{{{document_content}}}
+import math
+
+def calculate_circle_area(radius):
+ """Calculate the area of a circle given its radius."""
+ return math.pi * radius ** 2
+
+<rewrite_this>
+<insert_here></insert_here>
+</rewrite_this>
+class Circle:
+ def __init__(self, radius):
+ self.radius = radius
+
+ def area(self):
+ return math.pi * self.radius ** 2
+
+ def circumference(self):
+ return 2 * math.pi * self.radius
+
+# Usage example
+circle = Circle(5)
+print(f"Area: {circle.area():.2f}")
+print(f"Circumference: {circle.circumference():.2f}")
</document>
+<prompt>
+write docs
+</prompt>
+</input>
-{{#if is_truncated}}
-The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.
-{{/if}}
+<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
+"""
+Represents a circle with methods to calculate its area and circumference.
+
+This class provides a simple way to work with circles in a geometric context.
+It allows for the creation of Circle objects with a specified radius and
+offers methods to compute the circle's area and circumference.
+
+Attributes:
+ radius (float): The radius of the circle.
-Rewrite the section of {{content_type}} in <rewrite_this></rewrite_this> tags based on the following prompt:
+Methods:
+ area(): Calculates and returns the area of the circle.
+ circumference(): Calculates and returns the circumference of the circle.
+"""
+class Circle:
+</incorrect_output>
+<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
+"""
+Represents a circle with methods to calculate its area and circumference.
+This class provides a simple way to work with circles in a geometric context.
+It allows for the creation of Circle objects with a specified radius and
+offers methods to compute the circle's area and circumference.
+
+Attributes:
+ radius (float): The radius of the circle.
+
+Methods:
+ area(): Calculates and returns the area of the circle.
+ circumference(): Calculates and returns the circumference of the circle.
+"""
+</corrected_output>
+</example>
+
+<example>
+<input>
+<document>
+class BankAccount {
+ private balance: number;
+
+ constructor(initialBalance: number) {
+ this.balance = initialBalance;
+ }
+
+ <rewrite_this>
+ <insert_here></insert_here>
+ </rewrite_this>
+ deposit(amount: number): void {
+ if (amount > 0) {
+ this.balance += amount;
+ }
+ }
+
+ withdraw(amount: number): boolean {
+ if (amount > 0 && this.balance >= amount) {
+ this.balance -= amount;
+ return true;
+ }
+ return false;
+ }
+
+ getBalance(): number {
+ return this.balance;
+ }
+}
+
+// Usage
+const account = new BankAccount(1000);
+account.deposit(500);
+console.log(account.getBalance()); // 1500
+account.withdraw(200);
+console.log(account.getBalance()); // 1300
+</document>
<prompt>
-{{{user_prompt}}}
+//
</prompt>
+</input>
-Here's the section to edit based on that prompt again for reference:
+<incorrect_output failure="Over-generation. The text starting with `deposit(amount: number): void {` is *after* the rewrite_this tag">
+ /**
+ * Deposits the specified amount into the bank account.
+ *
+ * @param amount The amount to deposit. Must be a positive number.
+ * @throws Error if the amount is not positive.
+ */
+ deposit(amount: number): void {
+ if (amount > 0) {
+ this.balance += amount;
+ } else {
+ throw new Error("Deposit amount must be positive");
+ }
+ }
+</incorrect_output>
+<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
+ /**
+ * Deposits the specified amount into the bank account.
+ *
+ * @param amount The amount to deposit. Must be a positive number.
+ * @throws Error if the amount is not positive.
+ */
+</corrected_output>
+</example>
+
+<example>
+<input>
+<document>
+use std::collections::VecDeque;
+
+pub struct BinaryTree<T> {
+ root: Option<Node<T>>,
+}
<rewrite_this>
-{{{rewrite_section}}}
+<insert_here></insert_here>
</rewrite_this>
+struct Node<T> {
+ value: T,
+ left: Option<Box<Node<T>>>,
+ right: Option<Box<Node<T>>>,
+}
+</document>
+<prompt>
+derive clone
+</prompt>
+</input>
+
+<incorrect_output failure="Over-generation below the rewrite_this tags. Extra space between derive annotation and struct definition.">
+#[derive(Clone)]
+
+struct Node<T> {
+ value: T,
+ left: Option<Box<Node<T>>>,
+ right: Option<Box<Node<T>>>,
+}
+</incorrect_output>
+
+<incorrect_output failure="Over-generation above the rewrite_this tags">
+pub struct BinaryTree<T> {
+ root: Option<Node<T>>,
+}
+
+#[derive(Clone)]
+</incorrect_output>
+
+<incorrect_output failure="Over-generation below the rewrite_this tags">
+#[derive(Clone)]
+struct Node<T> {
+ value: T,
+ left: Option<Box<Node<T>>>,
+ right: Option<Box<Node<T>>>,
+}
-You'll rewrite this entire section, but you will only make changes within certain subsections.
+impl<T> Node<T> {
+ fn new(value: T) -> Self {
+ Node {
+ value,
+ left: None,
+ right: None,
+ }
+ }
+}
+</incorrect_output>
+<corrected_output improvement="Only includes the new content within the rewrite_this tags">
+#[derive(Clone)]
+</corrected_output>
+</example>
+<example>
+<input>
+<document>
+import math
+
+def calculate_circle_area(radius):
+ """Calculate the area of a circle given its radius."""
+ return math.pi * radius ** 2
+
+<rewrite_this>
+<insert_here></insert_here>
+</rewrite_this>
+class Circle:
+ def __init__(self, radius):
+ self.radius = radius
+
+ def area(self):
+ return math.pi * self.radius ** 2
+
+ def circumference(self):
+ return 2 * math.pi * self.radius
+
+# Usage example
+circle = Circle(5)
+print(f"Area: {circle.area():.2f}")
+print(f"Circumference: {circle.circumference():.2f}")
+</document>
+<prompt>
+add dataclass decorator
+</prompt>
+</input>
+
+<incorrect_output failure="Over-generation. The text starting with `class Circle:` is *after* the rewrite_this tag">
+@dataclass
+class Circle:
+ radius: float
+
+ def __init__(self, radius):
+ self.radius = radius
+
+ def area(self):
+ return math.pi * self.radius ** 2
+</incorrect_output>
+<corrected_output improvement="Generation stops before repeating content after the rewrite_this section">
+@dataclass
+</corrected_output>
+</example>
+
+<example>
+<input>
+<document>
+interface ShoppingCart {
+ items: string[];
+ total: number;
+}
+
+<rewrite_this>
+<insert_here></insert_here>class ShoppingCartManager {
+</rewrite_this>
+ private cart: ShoppingCart;
+
+ constructor() {
+ this.cart = { items: [], total: 0 };
+ }
+
+ addItem(item: string, price: number): void {
+ this.cart.items.push(item);
+ this.cart.total += price;
+ }
+
+ getTotal(): number {
+ return this.cart.total;
+ }
+}
+
+// Usage
+const manager = new ShoppingCartManager();
+manager.addItem("Book", 15.99);
+console.log(manager.getTotal()); // 15.99
+</document>
+<prompt>
+add readonly modifier
+</prompt>
+</input>
+
+<incorrect_output failure="Over-generation. The line starting with ` items: string[];` is *after* the rewrite_this tag">
+readonly interface ShoppingCart {
+ items: string[];
+ total: number;
+}
+
+class ShoppingCartManager {
+ private readonly cart: ShoppingCart;
+
+ constructor() {
+ this.cart = { items: [], total: 0 };
+ }
+</incorrect_output>
+<corrected_output improvement="Only includes the new content within the rewrite_this tags and integrates cleanly into surrounding code">
+readonly interface ShoppingCart {
+</corrected_output>
+</example>
+
+</examples>
+
+With these examples in mind, edit the following file:
+
+<document language="{{ language_name }}">
+{{{ document_content }}}
+</document>
+
+{{#if is_truncated}}
+The provided document has been truncated (potentially mid-line) for brevity.
+{{/if}}
+
+<instructions>
{{#if has_insertion}}
-Insert text anywhere you see it marked with with <insert_here></insert_here> tags. Do not include <insert_here> tags in your output.
+Insert text anywhere you see marked with <insert_here></insert_here> tags. It's CRITICAL that you DO NOT include <insert_here> tags in your output.
{{/if}}
{{#if has_replacement}}
-Edit edit text that you see surrounded with <edit_here></edit_here> tags. Do not include <edit_here> tags in your output.
+Edit text that you see surrounded with <edit_here>...</edit_here> tags. It's CRITICAL that you DO NOT include <edit_here> tags in your output.
{{/if}}
+Make no changes to the rewritten content outside these tags.
+<snippet language="{{ language_name }}" annotated="true">
+{{{ rewrite_section_prefix }}}
<rewrite_this>
-{{{rewrite_section_with_selections}}}
+{{{ rewrite_section_with_edits }}}
</rewrite_this>
+{{{ rewrite_section_suffix }}}
+</snippet>
+
+Rewrite the lines enclosed within the <rewrite_this></rewrite_this> tags in accordance with the provided instructions and the prompt below.
+
+<prompt>
+{{{ user_prompt }}}
+</prompt>
+
+Do not include <insert_here> or <edit_here> annotations in your output. Here is a clean copy of the snippet without annotations for your reference.
-Only make changes that are necessary to fulfill the prompt, leave everything else as-is. All surrounding {{content_type}} will be preserved. Do not output the <rewrite_this></rewrite this> tags or anything outside of them.
+<snippet>
+{{{ rewrite_section_prefix }}}
+{{{ rewrite_section }}}
+{{{ rewrite_section_suffix }}}
+</snippet>
+</instructions>
-Start at the indentation level in the original file in the rewritten {{content_type}}. Don't stop until you've rewritten the entire section, even if you have no more changes to make. Always write out the whole section with no unnecessary elisions.
+<guidelines_reminder>
+1. Focus on necessary changes: Modify only what's required to fulfill the prompt.
+2. Preserve context: Maintain all surrounding content as-is, ensuring the rewritten section seamlessly integrates with the existing document structure and flow.
+3. Exclude annotation tags: Do not output <rewrite_this>, </rewrite_this>, <edit_here>, or <insert_here> tags.
+4. Maintain indentation: Begin at the original file's indentation level.
+5. Complete rewrite: Continue until the entire section is rewritten, even if no further changes are needed.
+6. Avoid elisions: Always write out the full section without unnecessary omissions. NEVER say `// ...` or `// ...existing code` in your output.
+7. Respect content boundaries: Preserve code integrity.
+</guidelines_reminder>
Immediately start with the following format with no remarks:
@@ -1,87 +1,302 @@
+<workflow>
Guide the user through code changes in numbered steps that focus on individual functions, type definitions, etc.
-Surround each distinct step in a <step></step> XML tag.
+Surround each distinct step in a <step></step> XML tag. The user will be performing these steps in a code editor
+named Zed, which is where they will have entered this prompt and will be seeing the response.
-Here's an example of a user-assistant dialog with step numbers and a follow-up question in Rust:
+<instructions>
+- Use the language of the file for code fence blocks unless otherwise specified.
+- Include a code or file action in each step.
+- Only put code in separate steps if it should either go in separate files, or in different (non-contiguous) places in the same file.
+- Provide error handling and input validation where appropriate.
+- Adapt explanations based on the user's perceived level of expertise.
+- Include comments in code examples to enhance understanding.
+- Consider more complex programming scenarios when relevant.
+- Avoid using the terminal to perform filesystem operations such as creating, deleting, renaming, and editing files. Assume the user prefers to use Zed for these operations, unless the prompt suggests that the user is expecting terminal commands.
+- When creating a new file, Zed will automatically create any necessary directories in that file's path. So when a step will create a new file in Zed, don't mention needing to create the directory that it will go in.
+- Assume the user is performing these actions in the Zed code editor, so avoid redundancies like "In your code editor, ..." or "In Zed, ..."
+- Be concise without leaving out important information.
+</instructions>
-1. User: I want to create a User struct with name and age fields, and a method to print the user's information. Can you help me with that?
-A: Certainly! Let's create a User struct with the specified fields and a method to print the information:
-<step>Define the User struct
+Examples of user-assistant dialogs with step numbers and user follow-up requests:
+
+<example language="rust">
+<turn number="1">
+<user>Create a Library struct with a vector of Book structs. Include methods to add books and find books by author. Use generics for the book ID.</user>
+<assistant>Certainly! Let's create a Library struct with a vector of Book structs, including methods to add books and find books by author, using generics for the book ID:
+
+<step>Define the Book and Library structs
```rust
-struct User {
- name: String,
- age: u32,
+use std::fmt::Display;
+
+struct Book<T>
+where
+ T: Display + Clone
+{
+ id: T,
+ title: String,
+ author: String,
+}
+
+struct Library<T>
+where
+ T: Display + Clone
+{
+ books: Vec<Book<T>>,
}
```
</step>
-<step>Implement methods for the User struct
+<step>Implement methods for the Library struct
```rust
-impl User {
- fn new(name: String, age: u32) -> Self {
- User { name, age }
+impl<T> Library<T>
+where
+ T: Display + Clone
+{
+ fn new() -> Self {
+ Library { books: Vec::new() }
+ }
+
+ fn add_book(&mut self, id: T, title: String, author: String) {
+ let book = Book { id, title, author };
+ self.books.push(book);
}
- fn print_info(&self) {
- println!("Name: {}, Age: {}", self.name, self.age);
+
+ fn find_books_by_author(&self, author: &str) -> Vec<&Book<T>> {
+ self.books
+ .iter()
+ .filter(|book| book.author.to_lowercase() == author.to_lowercase())
+ .collect()
}
}
```
</step>
-This code defines a User struct and implements two methods:
-- `new`: a constructor to create a new User
-- `print_info`: a method to print the user's information
+This code defines:
+1. A generic Book struct with an ID of type T, title, and author
+2. A Library struct containing a vector of Books
+3. Methods to create a new Library, add books, and find books by author
+The use of generics allows flexibility in the type of ID used for books.</assistant>
+</turn>
-2. User: Great! Now, how can I create a User instance and call the print_info method?
-A: You can create a User instance and call the print_info method like this:
+<turn number="2">
+<user>How can I use this Library struct in a main function?</user>
+<assistant>Here's how you can use the Library struct in a main function:
-<step>Create a User instance and call print_info
+<step>Create a main function to demonstrate Library usage
```rust
fn main() {
- let user = User::new(String::from("Alice"), 30);
- user.print_info();
+ // Create a new Library with string IDs
+ let mut library = Library::new();
+
+ // Add some books
+ library.add_book(String::from("001"), String::from("The Rust Programming Language"), String::from("Steve Klabnik"));
+ library.add_book(String::from("002"), String::from("Programming Rust"), String::from("Jim Blandy"));
+ library.add_book(String::from("003"), String::from("Rust in Action"), String::from("Tim McNamara"));
+
+ // Find books by author
+ let author = "Steve Klabnik";
+ let books = library.find_books_by_author(author);
+
+ println!("Books by {}:", author);
+ for book in books {
+ println!("ID: {}, Title: {}", book.id, book.title);
+ }
}
```
</step>
-This will output: `Name: Alice, Age: 30`
+This main function demonstrates:
+1. Creating a new Library instance
+2. Adding books with string IDs
+3. Finding and displaying books by a specific author</assistant>
+</turn>
-3. User: Can you add an email field to the User struct and update the print_info method to include it?
-A: Certainly! Let's modify the User struct and update the print_info method:
+<turn number="3">
+<user>Can you add a method to remove a book by ID and handle potential errors?</user>
+<assistant>Certainly! Let's add a method to remove a book by ID and handle potential errors:
-<step>Add email field to User struct
+<step>Add a remove_book method to the Library struct
```rust
-struct User {
- name: String,
- age: u32,
- email: String,
+use std::fmt::Display;
+use std::error::Error;
+use std::fmt;
+
+#[derive(Debug)]
+struct BookNotFoundError;
+
+impl Error for BookNotFoundError {}
+
+impl fmt::Display for BookNotFoundError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "Book not found")
+ }
+}
+
+impl<T> Library<T>
+where
+ T: Display + Clone + PartialEq
+{
+ // ... (previous methods remain the same)
+
+ fn remove_book(&mut self, id: &T) -> Result<Book<T>, BookNotFoundError> {
+ if let Some(index) = self.books.iter().position(|book| &book.id == id) {
+ Ok(self.books.remove(index))
+ } else {
+ Err(BookNotFoundError)
+ }
+ }
}
```
</step>
-
-<step>Update the new method and print_info method
+<step>Update the main function to demonstrate book removal
```rust
-impl User {
- fn new(name: String, age: u32, email: String) -> Self {
- User { name, age, email }
+fn main() {
+ // ... (previous code remains the same)
+
+ // Attempt to remove a book
+ let id_to_remove = String::from("002");
+ match library.remove_book(&id_to_remove) {
+ Ok(removed_book) => println!("Removed book: {} by {}", removed_book.title, removed_book.author),
+ Err(e) => println!("Error removing book: {}", e),
}
- fn print_info(&self) {
- println!("Name: {}, Age: {}, Email: {}", self.name, self.age, self.email);
+
+ // Attempt to remove a non-existent book
+ let non_existent_id = String::from("999");
+ match library.remove_book(&non_existent_id) {
+ Ok(_) => println!("Book removed successfully"),
+ Err(e) => println!("Error removing book: {}", e),
+ }
+
+ // Display remaining books
+ println!("\nRemaining books:");
+ for book in &library.books {
+ println!("ID: {}, Title: {} by {}", book.id, book.title, book.author);
}
}
```
</step>
-<step>Update the main function
-```rust
-fn main() {
- let user = User::new(String::from("Alice"), 30, String::from("alice@example.com"));
- user.print_info();
+This update includes:
+1. A new remove_book method that returns a Result type
+2. A custom BookNotFoundError for error handling
+3. Demonstration of successful book removal and error handling for non-existent books
+4. Display of remaining books after removal attempts
+
+The remove_book method uses the PartialEq trait to compare book IDs, allowing for flexible ID types while maintaining type safety.</assistant>
+</turn>
+</example>
+
+<example language="go">
+<turn number="1">
+<user>Create a Person struct with name, age, and email fields, and a method to greet. Include error handling for age.</user>
+<assistant>Certainly! Let's create a Person struct with the specified fields and a method to greet in Go:
+
+<step>Define the Person struct and NewPerson function with error handling
+```go
+type Person struct {
+ Name string
+ Age int
+ Email string
+}
+
+func NewPerson(name string, age int, email string) (*Person, error) {
+ if age < 0 || age > 150 {
+ return nil, fmt.Errorf("invalid age: %d", age)
+ }
+ return &Person{Name: name, Age: age, Email: email}, nil
+}
+```
+</step>
+<step>Implement the Greet method for the Person struct
+```go
+func (p *Person) Greet() string {
+ return fmt.Sprintf("Hello, I'm %s, %d years old. You can reach me at %s.", p.Name, p.Age, p.Email)
+}
+```
+</step>
+
+This code defines a Person struct, a NewPerson function with age validation, and a Greet method.</assistant>
+</turn>
+
+<turn number="2">
+<user>How can I use these in a main function?</user>
+<assistant>Here's how to use the Person struct and its methods in a main function:
+
+<step>Create a main function with error handling
+```go
+func main() {
+ // Create a new Person instance
+ person, err := NewPerson("Alice", 30, "alice@example.com")
+ if err != nil {
+ fmt.Println("Error creating person:", err)
+ return
+ }
+
+ // Use the Greet method
+ fmt.Println(person.Greet())
+
+ // Demonstrate error handling with invalid age
+ invalidPerson, err := NewPerson("Bob", 200, "bob@example.com")
+ if err != nil {
+ fmt.Println("Error creating person with invalid age:", err)
+ }
+}
+```
+</step>
+
+This main function demonstrates:
+1. Creating a valid Person instance
+2. Using the Greet method
+3. Handling errors when creating a Person with an invalid age</assistant>
+</turn>
+
+<turn number="3">
+<user>Can you add a method to update the email with validation?</user>
+<assistant>Certainly! Let's add a method to update the email with validation:
+
+<step>Add UpdateEmail method with validation
+```go
+import (
+ "fmt"
+ "regexp"
+)
+
+func (p *Person) UpdateEmail(newEmail string) error {
+ // Simple email validation using regex
+ emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
+ if !emailRegex.MatchString(newEmail) {
+ return fmt.Errorf("invalid email format: %s", newEmail)
+ }
+ p.Email = newEmail
+ return nil
+}
+```
+</step>
+<step>Update the main function to demonstrate email update
+```go
+func main() {
+ // ... (previous code remains the same)
+
+ // Demonstrate email update
+ err = person.UpdateEmail("alice.new@example.com")
+ if err != nil {
+ fmt.Println("Error updating email:", err)
+ } else {
+ fmt.Println("Updated greeting:", person.Greet())
+ }
+
+ // Demonstrate email update with invalid email
+ err = person.UpdateEmail("invalid-email")
+ if err != nil {
+ fmt.Println("Error updating email:", err)
+ }
}
```
</step>
-This will now output: `Name: Alice, Age: 30, Email: alice@example.com`
-The changes include:
-1. Adding the `email` field to the User struct
-2. Updating the `new` method to accept an email parameter
-3. Modifying the `print_info` method to include the email
-4. Updating the main function to provide an email when creating a User instance
+This update includes:
+1. An UpdateEmail method with email format validation
+2. Demonstration of successful email update in the main function
+3. Handling of invalid email update attempt</assistant>
+</turn>
+</example>
+</workflow>
@@ -1,22 +1,27 @@
-Your task is to map a step from the conversation above to suggestions on symbols inside the provided source files.
+<overview>
+Your task is to map a step from a workflow to locations in source code where code needs to be changed to fulfill that step.
+Given a workflow containing background context plus a series of <step> tags, you will resolve *one* of these step tags to one or more locations in the code.
+With each location, you will produce a brief, one-line description of the changes to be made.
-Guidelines:
+<guidelines>
- There's no need to describe *what* to do, just *where* to do it.
+- Only reference locations that actually exist (unless you're creating a file).
- If creating a file, assume any subsequent updates are included at the time of creation.
-- Don't create and then update a file.
-- We'll create it in one shot.
+- Don't create and then update a file. Always create new files in one shot.
- Prefer updating symbols lower in the syntax tree if possible.
- Never include suggestions on a parent symbol and one of its children in the same suggestions block.
- Never nest an operation with another operation or include CDATA or other content. All suggestions are leaf nodes.
-- Include a description attribute for each operation with a brief, one-line description of the change to perform.
- Descriptions are required for all suggestions except delete.
- When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
- Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
- Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
+</guidelines>
+</overview>
-Example 1:
-
-User:
+<examples>
+<example>
+<workflow_context>
+<message role="user">
```rs src/rectangle.rs
struct Rectangle {
width: f64,
@@ -30,12 +35,21 @@ impl Rectangle {
}
```
+We need to add methods to calculate the area and perimeter of the rectangle. Can you help with that?
+</message>
+<message role="assistant">
+Sure, I can help with that!
+
<step>Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct</step>
<step>Implement the 'Display' trait for the Rectangle struct</step>
+</message>
+</workflow_context>
-What are the suggestions for the step: <step>Add a new method 'calculate_area' to the Rectangle struct</step>
+<step_to_resolve>
+Add new methods 'calculate_area' and 'calculate_perimeter' to the Rectangle struct
+</step_to_resolve>
-A (wrong):
+<incorrect_output reason="NEVER append multiple children at the same location.">
{
"title": "Add Rectangle methods",
"suggestions": [
@@ -53,10 +67,9 @@ A (wrong):
}
]
}
+</incorrect_output>
-This demonstrates what NOT to do. NEVER append multiple children at the same location.
-
-A (corrected):
+<correct_output>
{
"title": "Add Rectangle methods",
"suggestions": [
@@ -68,11 +81,13 @@ A (corrected):
}
]
}
+</correct_output>
-User:
-What are the suggestions for the step: <step>Implement the 'Display' trait for the Rectangle struct</step>
+<step_to_resolve>
+Implement the 'Display' trait for the Rectangle struct
+</step_to_resolve>
-A:
+<output>
{
"title": "Implement Display for Rectangle",
"suggestions": [
@@ -84,10 +99,11 @@ A:
}
]
}
+</output>
-Example 2:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/user.rs
struct User {
pub name: String,
@@ -105,13 +121,19 @@ impl User {
}
}
```
-
+</message>
+<message role="assistant">
+Certainly!
<step>Update the 'print_info' method to use formatted output</step>
<step>Remove the 'email' field from the User struct</step>
+</message>
+</workflow_context>
-What are the suggestions for the step: <step>Update the 'print_info' method to use formatted output</step>
+<step_to_resolve>
+Update the 'print_info' method to use formatted output
+</step_to_resolve>
-A:
+<output>
{
"title": "Use formatted output",
"suggestions": [
@@ -123,11 +145,13 @@ A:
}
]
}
+</output>
-User:
-What are the suggestions for the step: <step>Remove the 'email' field from the User struct</step>
+<step_to_resolve>
+Remove the 'email' field from the User struct
+</step_to_resolve>
-A:
+<output>
{
"title": "Remove email field",
"suggestions": [
@@ -138,10 +162,12 @@ A:
}
]
}
+</output>
+</example>
-Example 3:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/vehicle.rs
struct Vehicle {
make: String,
@@ -159,13 +185,18 @@ impl Vehicle {
}
}
```
-
+</message>
+<message role="assistant">
<step>Add a 'use std::fmt;' statement at the beginning of the file</step>
<step>Add a new method 'start_engine' in the Vehicle impl block</step>
+</message>
+</workflow_context>
-What are the suggestions for the step: <step>Add a 'use std::fmt;' statement at the beginning of the file</step>
+<step_to_resolve>
+Add a 'use std::fmt;' statement at the beginning of the file
+</step_to_resolve>
-A:
+<output>
{
"title": "Add use std::fmt statement",
"suggestions": [
@@ -176,11 +207,13 @@ A:
}
]
}
+</output>
-User:
-What are the suggestions for the step: <step>Add a new method 'start_engine' in the Vehicle impl block</step>
+<step_to_resolve>
+Add a new method 'start_engine' in the Vehicle impl block
+</step_to_resolve>
-A:
+<output>
{
"title": "Add start_engine method",
"suggestions": [
@@ -192,10 +225,12 @@ A:
}
]
}
+</output>
+</example>
-Example 4:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/employee.rs
struct Employee {
name: String,
@@ -219,12 +254,18 @@ impl Employee {
}
}
```
-
+</message>
+<message role="assistant">
<step>Make salary an f32</step>
+<step>Remove the 'department' field and update the 'print_details' method</step>
+</message>
+</workflow_context>
-What are the suggestions for the step: <step>Make salary an f32</step>
+<step_to_resolve>
+Make salary an f32
+</step_to_resolve>
-A (wrong):
+<incorrect_output reason="NEVER include suggestions on a parent symbol and one of its children in the same suggestions block.">
{
"title": "Change salary to f32",
"suggestions": [
@@ -242,10 +283,9 @@ A (wrong):
}
]
}
+</incorrect_output>
-This example demonstrates what not to do. `struct Employee salary` is a child of `struct Employee`.
-
-A (corrected):
+<correct_output>
{
"title": "Change salary to f32",
"suggestions": [
@@ -257,11 +297,13 @@ A (corrected):
}
]
}
+</correct_output>
-User:
-What are the correct suggestions for the step: <step>Remove the 'department' field and update the 'print_details' method</step>
+<step_to_resolve>
+Remove the 'department' field and update the 'print_details' method
+</step_to_resolve>
-A:
+<output>
{
"title": "Remove department",
"suggestions": [
@@ -278,10 +320,12 @@ A:
}
]
}
+</output>
+</example>
-Example 5:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/game.rs
struct Player {
name: String,
@@ -305,10 +349,17 @@ impl Game {
}
}
```
-
+</message>
+<message role="assistant">
<step>Add a 'level' field to Player and update the 'new' method</step>
+</message>
+</workflow_context>
+
+<step_to_resolve>
+Add a 'level' field to Player and update the 'new' method
+</step_to_resolve>
-A:
+<output>
{
"title": "Add level field to Player",
"suggestions": [
@@ -326,10 +377,12 @@ A:
}
]
}
+</output>
+</example>
-Example 6:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/config.rs
use std::collections::HashMap;
@@ -343,10 +396,17 @@ impl Config {
}
}
```
-
+</message>
+<message role="assistant">
<step>Add a 'load_from_file' method to Config and import necessary modules</step>
+</message>
+</workflow_context>
+
+<step_to_resolve>
+Add a 'load_from_file' method to Config and import necessary modules
+</step_to_resolve>
-A:
+<output>
{
"title": "Add load_from_file method",
"suggestions": [
@@ -363,10 +423,12 @@ A:
}
]
}
+</output>
+</example>
-Example 7:
-
-User:
+<example>
+<workflow_context>
+<message role="user">
```rs src/database.rs
pub(crate) struct Database {
connection: Connection,
@@ -383,10 +445,17 @@ impl Database {
}
}
```
-
+</message>
+<message role="assistant">
<step>Add error handling to the 'query' method and create a custom error type</step>
+</message>
+</workflow_context>
+
+<step_to_resolve>
+Add error handling to the 'query' method and create a custom error type
+</step_to_resolve>
-A:
+<output>
{
"title": "Add error handling to query",
"suggestions": [
@@ -409,5 +478,16 @@ A:
}
]
}
+</output>
+</example>
+</examples>
Now generate the suggestions for the following step:
+
+<workflow_context>
+{{{workflow_context}}}
+</workflow_context>
+
+<step_to_resolve>
+{{{step_to_resolve}}}
+</step_to_resolve>
@@ -1010,5 +1010,16 @@
// ]
// }
// ]
- "ssh_connections": null
+ "ssh_connections": null,
+ // Configures the Context Server Protocol binaries
+ //
+ // Examples:
+ // {
+ // "id": "server-1",
+ // "executable": "/path",
+ // "args": ["arg1", "arg2"]
+ // }
+ "experimental.context_servers": {
+ "servers": []
+ }
}
@@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
-use std::str::FromStr;
use std::time::Duration;
+use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@@ -14,6 +14,14 @@ pub use supported_countries::*;
pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct AnthropicModelCacheConfiguration {
+ pub min_total_token: usize,
+ pub should_speculate: bool,
+ pub max_cache_anchors: usize,
+}
+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
@@ -32,6 +40,9 @@ pub enum Model {
max_tokens: usize,
/// Override this model with a different Anthropic model for tool calls.
tool_override: Option<String>,
+ /// Indicates whether this custom model supports caching.
+ cache_configuration: Option<AnthropicModelCacheConfiguration>,
+ max_output_tokens: Option<u32>,
},
}
@@ -70,6 +81,21 @@ impl Model {
}
}
+ pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
+ match self {
+ Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
+ min_total_token: 2_048,
+ should_speculate: true,
+ max_cache_anchors: 4,
+ }),
+ Self::Custom {
+ cache_configuration,
+ ..
+ } => cache_configuration.clone(),
+ _ => None,
+ }
+ }
+
pub fn max_token_count(&self) -> usize {
match self {
Self::Claude3_5Sonnet
@@ -80,6 +106,16 @@ impl Model {
}
}
+ pub fn max_output_tokens(&self) -> u32 {
+ match self {
+ Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
+ Self::Claude3_5Sonnet => 8_192,
+ Self::Custom {
+ max_output_tokens, ..
+ } => max_output_tokens.unwrap_or(4_096),
+ }
+ }
+
pub fn tool_model_id(&self) -> &str {
if let Self::Custom {
tool_override: Some(tool_override),
@@ -104,7 +140,10 @@ pub async fn complete(
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
- .header("Anthropic-Beta", "tools-2024-04-04")
+ .header(
+ "Anthropic-Beta",
+ "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
+ )
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
@@ -161,7 +200,10 @@ pub async fn stream_completion(
.method(Method::POST)
.uri(uri)
.header("Anthropic-Version", "2023-06-01")
- .header("Anthropic-Beta", "tools-2024-04-04")
+ .header(
+ "Anthropic-Beta",
+ "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
+ )
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
if let Some(low_speed_timeout) = low_speed_timeout {
@@ -226,7 +268,7 @@ pub fn extract_text_from_events(
match response {
Ok(response) => match response {
Event::ContentBlockStart { content_block, .. } => match content_block {
- Content::Text { text } => Some(Ok(text)),
+ Content::Text { text, .. } => Some(Ok(text)),
_ => None,
},
Event::ContentBlockDelta { delta, .. } => match delta {
@@ -241,13 +283,69 @@ pub fn extract_text_from_events(
})
}
+pub async fn extract_tool_args_from_events(
+ tool_name: String,
+ mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+ let mut tool_use_index = None;
+ while let Some(event) = events.next().await {
+ if let Event::ContentBlockStart {
+ index,
+ content_block,
+ } = event?
+ {
+ if let Content::ToolUse { name, .. } = content_block {
+ if name == tool_name {
+ tool_use_index = Some(index);
+ break;
+ }
+ }
+ }
+ }
+
+ let Some(tool_use_index) = tool_use_index else {
+ return Err(anyhow!("tool not used"));
+ };
+
+ Ok(events.filter_map(move |event| {
+ let result = match event {
+ Err(error) => Some(Err(error)),
+ Ok(Event::ContentBlockDelta { index, delta }) => match delta {
+ ContentDelta::TextDelta { .. } => None,
+ ContentDelta::InputJsonDelta { partial_json } => {
+ if index == tool_use_index {
+ Some(Ok(partial_json))
+ } else {
+ None
+ }
+ }
+ },
+ _ => None,
+ };
+
+ async move { result }
+ }))
+}
+
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum CacheControlType {
+ Ephemeral,
+}
+
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+pub struct CacheControl {
+ #[serde(rename = "type")]
+ pub cache_type: CacheControlType,
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,
pub content: Vec<Content>,
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
@@ -258,19 +356,31 @@ pub enum Role {
#[serde(tag = "type")]
pub enum Content {
#[serde(rename = "text")]
- Text { text: String },
+ Text {
+ text: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ cache_control: Option<CacheControl>,
+ },
#[serde(rename = "image")]
- Image { source: ImageSource },
+ Image {
+ source: ImageSource,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ cache_control: Option<CacheControl>,
+ },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ cache_control: Option<CacheControl>,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ cache_control: Option<CacheControl>,
},
}
@@ -33,11 +33,13 @@ clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
db.workspace = true
+context_servers.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
futures.workspace = true
fuzzy.workspace = true
+globset.workspace = true
gpui.workspace = true
handlebars.workspace = true
heed.workspace = true
@@ -67,6 +69,7 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smallvec.workspace = true
+similar.workspace = true
smol.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
@@ -3,7 +3,6 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod context;
-pub(crate) mod context_inspector;
pub mod context_store;
mod inline_assistant;
mod model_selector;
@@ -13,6 +12,7 @@ mod slash_command;
pub mod slash_command_settings;
mod streaming_diff;
mod terminal_inline_assistant;
+mod workflow;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
use assistant_settings::AssistantSettings;
@@ -20,9 +20,11 @@ use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
pub use context::*;
+use context_servers::ContextServerRegistry;
pub use context_store::*;
use feature_flags::FeatureFlagAppExt;
use fs::Fs;
+use gpui::Context as _;
use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
@@ -36,13 +38,14 @@ use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsStore};
use slash_command::{
- default_command, diagnostics_command, docs_command, fetch_command, file_command, now_command,
- project_command, prompt_command, search_command, symbols_command, tab_command,
- terminal_command, workflow_command,
+ context_server_command, default_command, diagnostics_command, docs_command, fetch_command,
+ file_command, now_command, project_command, prompt_command, search_command, symbols_command,
+ tab_command, terminal_command, workflow_command,
};
use std::sync::Arc;
pub(crate) use streaming_diff::*;
use util::ResultExt;
+pub use workflow::*;
use crate::slash_command_settings::SlashCommandSettings;
@@ -61,7 +64,6 @@ actions!(
DeployPromptLibrary,
ConfirmCommand,
ToggleModelSelector,
- DebugWorkflowSteps
]
);
@@ -219,6 +221,7 @@ pub fn init(
init_language_model_settings(cx);
assistant_slash_command::init(cx);
assistant_panel::init(cx);
+ context_servers::init(cx);
let prompt_builder = prompts::PromptBuilder::new(Some(PromptOverrideContext {
dev_mode,
@@ -259,9 +262,69 @@ pub fn init(
})
.detach();
+ register_context_server_handlers(cx);
+
prompt_builder
}
+fn register_context_server_handlers(cx: &mut AppContext) {
+ cx.subscribe(
+ &context_servers::manager::ContextServerManager::global(cx),
+ |manager, event, cx| match event {
+ context_servers::manager::Event::ServerStarted { server_id } => {
+ cx.update_model(
+ &manager,
+ |manager: &mut context_servers::manager::ContextServerManager, cx| {
+ let slash_command_registry = SlashCommandRegistry::global(cx);
+ let context_server_registry = ContextServerRegistry::global(cx);
+ if let Some(server) = manager.get_server(server_id) {
+ cx.spawn(|_, _| async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return;
+ };
+
+ if let Some(prompts) = protocol.list_prompts().await.log_err() {
+ for prompt in prompts
+ .into_iter()
+ .filter(context_server_command::acceptable_prompt)
+ {
+ log::info!(
+ "registering context server command: {:?}",
+ prompt.name
+ );
+ context_server_registry.register_command(
+ server.id.clone(),
+ prompt.name.as_str(),
+ );
+ slash_command_registry.register_command(
+ context_server_command::ContextServerSlashCommand::new(
+ &server, prompt,
+ ),
+ true,
+ );
+ }
+ }
+ })
+ .detach();
+ }
+ },
+ );
+ }
+ context_servers::manager::Event::ServerStopped { server_id } => {
+ let slash_command_registry = SlashCommandRegistry::global(cx);
+ let context_server_registry = ContextServerRegistry::global(cx);
+ if let Some(commands) = context_server_registry.get_commands(server_id) {
+ for command_name in commands {
+ slash_command_registry.unregister_command_by_name(&command_name);
+ context_server_registry.unregister_command(&server_id, &command_name);
+ }
+ }
+ }
+ },
+ )
+ .detach();
+}
+
fn init_language_model_settings(cx: &mut AppContext) {
update_active_language_model_from_settings(cx);
@@ -1,6 +1,5 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings},
- context_inspector::ContextInspector,
humanize_token_count,
prompt_library::open_prompt_library,
prompts::PromptBuilder,
@@ -12,12 +11,12 @@ use crate::{
},
terminal_inline_assistant::TerminalInlineAssistant,
Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
- DebugWorkflowSteps, DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId,
- InlineAssistant, InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand,
- PendingSlashCommandStatus, QuoteSelection, RemoteContextMetadata, ResolvedWorkflowStep,
- SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
+ DeployHistory, DeployPromptLibrary, InlineAssist, InlineAssistId, InlineAssistant,
+ InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
+ QuoteSelection, RemoteContextMetadata, SavedContextMetadata, Split, ToggleFocus,
+ ToggleModelSelector, WorkflowStepResolution, WorkflowStepView,
};
-use crate::{ContextStoreEvent, ShowConfiguration};
+use crate::{ContextStoreEvent, ModelPickerDelegate, ShowConfiguration};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use client::{proto, Client, Status};
@@ -36,10 +35,10 @@ use fs::Fs;
use gpui::{
canvas, div, img, percentage, point, pulsating_between, size, Action, Animation, AnimationExt,
AnyElement, AnyView, AppContext, AsyncWindowContext, ClipboardEntry, ClipboardItem,
- Context as _, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle, FocusableView,
- FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels, ReadGlobal, Render,
- RenderImage, SharedString, Size, StatefulInteractiveElement, Styled, Subscription, Task,
- Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
+ Context as _, CursorStyle, DismissEvent, Empty, Entity, EntityId, EventEmitter, FocusHandle,
+ FocusableView, FontWeight, InteractiveElement, IntoElement, Model, ParentElement, Pixels,
+ ReadGlobal, Render, RenderImage, SharedString, Size, StatefulInteractiveElement, Styled,
+ Subscription, Task, Transformation, UpdateGlobal, View, VisualContext, WeakView, WindowContext,
};
use indexed_docs::IndexedDocsStore;
use language::{
@@ -57,15 +56,14 @@ use settings::{update_settings_file, Settings};
use smol::stream::StreamExt;
use std::{
borrow::Cow,
- cmp::{self, Ordering},
+ cmp,
fmt::Write,
- ops::Range,
+ ops::{DerefMut, Range},
path::PathBuf,
sync::Arc,
time::Duration,
};
use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
-use text::OffsetRangeExt;
use ui::TintColor;
use ui::{
prelude::*,
@@ -77,7 +75,6 @@ use util::ResultExt;
use workspace::{
dock::{DockPosition, Panel, PanelEvent},
item::{self, FollowableItem, Item, ItemHandle},
- notifications::NotifyTaskExt,
pane::{self, SaveIntent},
searchable::{SearchEvent, SearchableItem},
Pane, Save, ToggleZoom, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
@@ -145,7 +142,7 @@ pub struct AssistantPanel {
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
- model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
+ model_selector_menu_handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>,
model_summary_editor: View<Editor>,
authenticate_provider_task: Option<(LanguageModelProviderId, Task<()>)>,
configuration_subscription: Option<Subscription>,
@@ -404,56 +401,13 @@ impl AssistantPanel {
} else {
"Zoom In"
};
- let weak_pane = cx.view().downgrade();
let menu = ContextMenu::build(cx, |menu, cx| {
- let menu = menu
- .context(pane.focus_handle(cx))
+ menu.context(pane.focus_handle(cx))
.action("New Context", Box::new(NewFile))
.action("History", Box::new(DeployHistory))
.action("Prompt Library", Box::new(DeployPromptLibrary))
.action("Configure", Box::new(ShowConfiguration))
- .action(zoom_label, Box::new(ToggleZoom));
-
- if let Some(editor) = pane
- .active_item()
- .and_then(|e| e.downcast::<ContextEditor>())
- {
- let is_enabled = editor.read(cx).debug_inspector.is_some();
- menu.separator().toggleable_entry(
- "Debug Workflows",
- is_enabled,
- IconPosition::End,
- None,
- move |cx| {
- weak_pane
- .update(cx, |this, cx| {
- if let Some(context_editor) =
- this.active_item().and_then(|item| {
- item.downcast::<ContextEditor>()
- })
- {
- context_editor.update(cx, |this, cx| {
- if let Some(mut state) =
- this.debug_inspector.take()
- {
- state.deactivate(cx);
- } else {
- this.debug_inspector = Some(
- ContextInspector::new(
- this.editor.clone(),
- this.context.clone(),
- ),
- );
- }
- })
- }
- })
- .ok();
- },
- )
- } else {
- menu
- }
+ .action(zoom_label, Box::new(ToggleZoom))
});
cx.subscribe(&menu, |pane, _, _: &DismissEvent, _| {
pane.new_item_menu = None;
@@ -1380,15 +1334,17 @@ struct WorkflowStep {
range: Range<language::Anchor>,
header_block_id: CustomBlockId,
footer_block_id: CustomBlockId,
- resolved_step: Option<Result<ResolvedWorkflowStep, Arc<anyhow::Error>>>,
+ resolved_step: Option<Result<WorkflowStepResolution, Arc<anyhow::Error>>>,
assist: Option<WorkflowAssist>,
}
impl WorkflowStep {
fn status(&self, cx: &AppContext) -> WorkflowStepStatus {
match self.resolved_step.as_ref() {
- Some(Ok(_)) => {
- if let Some(assist) = self.assist.as_ref() {
+ Some(Ok(step)) => {
+ if step.suggestion_groups.is_empty() {
+ WorkflowStepStatus::Empty
+ } else if let Some(assist) = self.assist.as_ref() {
let assistant = InlineAssistant::global(cx);
if assist
.assist_ids
@@ -1424,6 +1380,7 @@ impl WorkflowStep {
enum WorkflowStepStatus {
Resolving,
Error(Arc<anyhow::Error>),
+ Empty,
Idle,
Pending,
Done,
@@ -1435,6 +1392,45 @@ impl WorkflowStepStatus {
matches!(self, Self::Confirmed)
}
+ fn render_workflow_step_error(
+ id: EntityId,
+ editor: WeakView<ContextEditor>,
+ step_range: Range<language::Anchor>,
+ error: String,
+ ) -> AnyElement {
+ h_flex()
+ .gap_2()
+ .child(
+ div()
+ .id("step-resolution-failure")
+ .child(
+ Label::new("Step Resolution Failed")
+ .size(LabelSize::Small)
+ .color(Color::Error),
+ )
+ .tooltip(move |cx| Tooltip::text(error.clone(), cx)),
+ )
+ .child(
+ Button::new(("transform", id), "Retry")
+ .icon(IconName::Update)
+ .icon_position(IconPosition::Start)
+ .icon_size(IconSize::Small)
+ .label_size(LabelSize::Small)
+ .on_click({
+ let editor = editor.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ editor
+ .update(cx, |this, cx| {
+ this.resolve_workflow_step(step_range.clone(), cx)
+ })
+ .ok();
+ }
+ }),
+ )
+ .into_any()
+ }
+
pub(crate) fn into_element(
&self,
step_range: Range<language::Anchor>,
@@ -1469,44 +1465,20 @@ impl WorkflowStepStatus {
|label, delta| label.alpha(delta),
)
.into_any_element(),
-
- WorkflowStepStatus::Error(error) => {
- let error = error.clone();
- h_flex()
- .gap_2()
- .child(
- div()
- .id("step-resolution-failure")
- .child(
- Label::new("Step Resolution Failed")
- .size(LabelSize::Small)
- .color(Color::Error),
- )
- .tooltip(move |cx| Tooltip::text(error.to_string(), cx)),
- )
- .child(
- Button::new(("transform", id), "Retry")
- .icon(IconName::Update)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .label_size(LabelSize::Small)
- .on_click({
- let editor = editor.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- editor
- .update(cx, |this, cx| {
- this.resolve_workflow_step(step_range.clone(), cx)
- })
- .ok();
- }
- }),
- )
- .into_any()
- }
-
+ WorkflowStepStatus::Error(error) => Self::render_workflow_step_error(
+ id,
+ editor.clone(),
+ step_range.clone(),
+ error.to_string(),
+ ),
+ WorkflowStepStatus::Empty => Self::render_workflow_step_error(
+ id,
+ editor.clone(),
+ step_range.clone(),
+ "Model was unable to locate the code to edit".to_string(),
+ ),
WorkflowStepStatus::Idle => Button::new(("transform", id), "Transform")
- .icon(IconName::Sparkle)
+ .icon(IconName::SparkleAlt)
.icon_position(IconPosition::Start)
.icon_size(IconSize::Small)
.label_size(LabelSize::Small)
@@ -1559,7 +1531,8 @@ impl WorkflowStepStatus {
.child(
IconButton::new(("stop-transformation", id), IconName::Stop)
.icon_size(IconSize::Small)
- .style(ButtonStyle::Tinted(TintColor::Negative))
+ .icon_color(Color::Error)
+ .style(ButtonStyle::Subtle)
.tooltip({
let step_range = step_range.clone();
let editor = editor.clone();
@@ -1725,7 +1698,6 @@ pub struct ContextEditor {
active_workflow_step: Option<ActiveWorkflowStep>,
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
- debug_inspector: Option<ContextInspector>,
show_accept_terms: bool,
}
@@ -1787,7 +1759,6 @@ impl ContextEditor {
active_workflow_step: None,
assistant_panel,
error_message: None,
- debug_inspector: None,
show_accept_terms: false,
};
this.update_message_headers(cx);
@@ -1872,7 +1843,7 @@ impl ContextEditor {
self.confirm_workflow_step(range, cx);
true
}
- WorkflowStepStatus::Error(_) => {
+ WorkflowStepStatus::Error(_) | WorkflowStepStatus::Empty => {
self.resolve_workflow_step(range, cx);
true
}
@@ -1885,9 +1856,8 @@ impl ContextEditor {
range: Range<language::Anchor>,
cx: &mut ViewContext<Self>,
) {
- self.context.update(cx, |context, cx| {
- context.resolve_workflow_step(range, self.project.clone(), cx)
- });
+ self.context
+ .update(cx, |context, cx| context.resolve_workflow_step(range, cx));
}
fn stop_workflow_step(&mut self, range: Range<language::Anchor>, cx: &mut ViewContext<Self>) {
@@ -1998,51 +1968,6 @@ impl ContextEditor {
cx.propagate();
}
- fn debug_workflow_steps(&mut self, _: &DebugWorkflowSteps, cx: &mut ViewContext<Self>) {
- let mut output = String::new();
- for (i, step) in self.context.read(cx).workflow_steps().iter().enumerate() {
- output.push_str(&format!("Step {}:\n", i + 1));
- output.push_str(&format!(
- "Content: {}\n",
- self.context
- .read(cx)
- .buffer()
- .read(cx)
- .text_for_range(step.tagged_range.clone())
- .collect::<String>()
- ));
- match &step.status {
- crate::WorkflowStepStatus::Resolved(ResolvedWorkflowStep {
- title,
- suggestions,
- }) => {
- output.push_str("Resolution:\n");
- output.push_str(&format!(" {:?}\n", title));
- output.push_str(&format!(" {:?}\n", suggestions));
- }
- crate::WorkflowStepStatus::Pending(_) => {
- output.push_str("Resolution: Pending\n");
- }
- crate::WorkflowStepStatus::Error(error) => {
- writeln!(output, "Resolution: Error\n{:?}", error).unwrap();
- }
- }
- output.push('\n');
- }
-
- let editor = self
- .workspace
- .update(cx, |workspace, cx| Editor::new_in_workspace(workspace, cx));
-
- if let Ok(editor) = editor {
- cx.spawn(|_, mut cx| async move {
- let editor = editor.await?;
- editor.update(&mut cx, |editor, cx| editor.set_text(output, cx))
- })
- .detach_and_notify_err(cx);
- }
- }
-
fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
let cursors = self.cursors(cx);
self.context.update(cx, |context, cx| {
@@ -2086,7 +2011,7 @@ impl ContextEditor {
}
editor.insert(&format!("/{name}"), cx);
- if command.requires_argument() {
+ if command.accepts_arguments() {
editor.insert(" ", cx);
editor.show_completions(&ShowCompletions::default(), cx);
}
@@ -2099,6 +2024,10 @@ impl ContextEditor {
}
pub fn confirm_command(&mut self, _: &ConfirmCommand, cx: &mut ViewContext<Self>) {
+ if self.editor.read(cx).has_active_completions_menu() {
+ return;
+ }
+
let selections = self.editor.read(cx).selections.disjoint_anchors();
let mut commands_by_range = HashMap::default();
let workspace = self.workspace.clone();
@@ -2460,9 +2389,6 @@ impl ContextEditor {
blocks_to_remove.insert(step.header_block_id);
blocks_to_remove.insert(step.footer_block_id);
}
- if let Some(debug) = self.debug_inspector.as_mut() {
- debug.deactivate_for(step_range, cx);
- }
}
self.editor.update(cx, |editor, cx| {
editor.remove_blocks(blocks_to_remove, None, cx)
@@ -2481,17 +2407,14 @@ impl ContextEditor {
let Some(step) = self
.context
.read(cx)
- .workflow_step_for_range(step_range.clone())
+ .workflow_step_for_range(step_range.clone(), cx)
else {
return;
};
- let resolved_step = step.status.into_resolved();
+ let resolved_step = step.read(cx).resolution.clone();
if let Some(existing_step) = self.workflow_steps.get_mut(&step_range) {
existing_step.resolved_step = resolved_step;
- if let Some(debug) = self.debug_inspector.as_mut() {
- debug.refresh(&step_range, cx);
- }
} else {
let start = buffer_snapshot
.anchor_in_excerpt(excerpt_id, step_range.start)
@@ -2531,85 +2454,102 @@ impl ContextEditor {
} else {
theme.info_border
};
- let step_index = weak_self.update(&mut **cx, |this, cx| {
- let snapshot = this.editor.read(cx).buffer().read(cx).as_singleton()?.read(cx).text_snapshot();
- let start_offset = step_range.start.to_offset(&snapshot);
- let parent_message = this.context.read(cx).messages_for_offsets([start_offset], cx);
- debug_assert_eq!(parent_message.len(), 1);
- let parent_message = parent_message.first()?;
-
- let index_of_current_step = this.workflow_steps.keys().filter(|workflow_step_range| workflow_step_range.start.cmp(&parent_message.anchor, &snapshot).is_ge() && workflow_step_range.end.cmp(&step_range.end, &snapshot).is_le()).count();
- Some(index_of_current_step)
- }).ok().flatten();
-
- let debug_header = weak_self
- .update(&mut **cx, |this, _| {
- if let Some(inspector) = this.debug_inspector.as_mut() {
- Some(inspector.is_active(&step_range))
- } else {
- None
- }
+ let step_index = weak_self
+ .update(&mut **cx, |this, cx| {
+ let snapshot = this
+ .editor
+ .read(cx)
+ .buffer()
+ .read(cx)
+ .as_singleton()?
+ .read(cx)
+ .text_snapshot();
+ let start_offset =
+ step_range.start.to_offset(&snapshot);
+ let parent_message = this
+ .context
+ .read(cx)
+ .messages_for_offsets([start_offset], cx);
+ debug_assert_eq!(parent_message.len(), 1);
+ let parent_message = parent_message.first()?;
+
+ let index_of_current_step = this
+ .workflow_steps
+ .keys()
+ .filter(|workflow_step_range| {
+ workflow_step_range
+ .start
+ .cmp(&parent_message.anchor, &snapshot)
+ .is_ge()
+ && workflow_step_range
+ .end
+ .cmp(&step_range.end, &snapshot)
+ .is_le()
+ })
+ .count();
+ Some(index_of_current_step)
})
- .unwrap_or_default();
- let step_label = if let Some(index) = step_index {
+ .ok()
+ .flatten();
+ let step_label = if let Some(index) = step_index {
Label::new(format!("Step {index}")).size(LabelSize::Small)
- } else {
- Label::new("Step").size(LabelSize::Small)
- };
- let step_label = if current_status.as_ref().is_some_and(|status| status.is_confirmed()) {
- h_flex().items_center().gap_2().child(step_label.strikethrough(true).color(Color::Muted)).child(Icon::new(IconName::Check).size(IconSize::Small).color(Color::Created))
+ } else {
+ Label::new("Step").size(LabelSize::Small)
+ };
+
+ let step_label = if current_status
+ .as_ref()
+ .is_some_and(|status| status.is_confirmed())
+ {
+ h_flex()
+ .items_center()
+ .gap_2()
+ .child(
+ step_label.strikethrough(true).color(Color::Muted),
+ )
+ .child(
+ Icon::new(IconName::Check)
+ .size(IconSize::Small)
+ .color(Color::Created),
+ )
} else {
div().child(step_label)
};
+
+ let step_label = step_label
+ .id("step")
+ .cursor(CursorStyle::PointingHand)
+ .on_click({
+ let this = weak_self.clone();
+ let step_range = step_range.clone();
+ move |_, cx| {
+ this.update(cx, |this, cx| {
+ this.open_workflow_step(step_range.clone(), cx);
+ })
+ .ok();
+ }
+ });
+
div()
.w_full()
.px(cx.gutter_dimensions.full_width())
.child(
h_flex()
.w_full()
+ .h_8()
.border_b_1()
.border_color(border_color)
- .pb_1p5()
+ .pb_2()
+ .items_center()
.justify_between()
.gap_2()
- .child(h_flex().justify_start().gap_2().child(step_label).children(
- debug_header.map(|is_active| {
-
- Button::new("debug-workflows-toggle", "Debug")
- .icon_color(Color::Hidden)
- .color(Color::Hidden)
- .selected_icon_color(Color::Default)
- .selected_label_color(Color::Default)
- .icon(IconName::Microscope)
- .icon_position(IconPosition::Start)
- .icon_size(IconSize::Small)
- .label_size(LabelSize::Small)
- .selected(is_active)
- .on_click({
- let weak_self = weak_self.clone();
- let step_range = step_range.clone();
- move |_, cx| {
- weak_self
- .update(cx, |this, cx| {
- if let Some(inspector) =
- this.debug_inspector
- .as_mut()
- {
- if is_active {
-
- inspector.deactivate_for(&step_range, cx);
- } else {
- inspector.activate_for_step(step_range.clone(), cx);
- }
- }
- })
- .ok();
- }
- })
- })
-
- ))
+ .child(
+ h_flex()
+ .justify_start()
+ .gap_2()
+ .child(step_label),
+ )
.children(current_status.as_ref().map(|status| {
h_flex().w_full().justify_end().child(
status.into_element(
@@ -2679,6 +2619,38 @@ impl ContextEditor {
self.update_active_workflow_step(cx);
}
+ fn open_workflow_step(
+ &mut self,
+ step_range: Range<language::Anchor>,
+ cx: &mut ViewContext<Self>,
+ ) -> Option<()> {
+ let pane = self
+ .assistant_panel
+ .update(cx, |panel, _| panel.pane())
+ .ok()??;
+ let context = self.context.read(cx);
+ let language_registry = context.language_registry();
+ let step = context.workflow_step_for_range(step_range, cx)?;
+ let context = self.context.clone();
+ cx.deref_mut().defer(move |cx| {
+ pane.update(cx, |pane, cx| {
+ let existing_item = pane
+ .items_of_type::<WorkflowStepView>()
+ .find(|item| *item.read(cx).step() == step.downgrade());
+ if let Some(item) = existing_item {
+ if let Some(index) = pane.index_for_item(&item) {
+ pane.activate_item(index, true, true, cx);
+ }
+ } else {
+ let view = cx
+ .new_view(|cx| WorkflowStepView::new(context, step, language_registry, cx));
+ pane.add_item(Box::new(view), true, true, None, cx);
+ }
+ });
+ });
+ None
+ }
+
fn update_active_workflow_step(&mut self, cx: &mut ViewContext<Self>) {
let new_step = self.active_workflow_step_for_cursor(cx);
if new_step.as_ref() != self.active_workflow_step.as_ref() {
@@ -2793,25 +2765,31 @@ impl ContextEditor {
}
fn open_assists_for_step(
- resolved_step: &ResolvedWorkflowStep,
+ resolved_step: &WorkflowStepResolution,
project: &Model<Project>,
assistant_panel: &WeakView<AssistantPanel>,
workspace: &WeakView<Workspace>,
cx: &mut ViewContext<Self>,
) -> Option<WorkflowAssist> {
let assistant_panel = assistant_panel.upgrade()?;
- if resolved_step.suggestions.is_empty() {
+ if resolved_step.suggestion_groups.is_empty() {
return None;
}
let editor;
let mut editor_was_open = false;
let mut suggestion_groups = Vec::new();
- if resolved_step.suggestions.len() == 1
- && resolved_step.suggestions.values().next().unwrap().len() == 1
+ if resolved_step.suggestion_groups.len() == 1
+ && resolved_step
+ .suggestion_groups
+ .values()
+ .next()
+ .unwrap()
+ .len()
+ == 1
{
// If there's only one buffer and one suggestion group, open it directly
- let (buffer, groups) = resolved_step.suggestions.iter().next().unwrap();
+ let (buffer, groups) = resolved_step.suggestion_groups.iter().next().unwrap();
let group = groups.into_iter().next().unwrap();
editor = workspace
.update(cx, |workspace, cx| {
@@ -2864,7 +2842,7 @@ impl ContextEditor {
let replica_id = project.read(cx).replica_id();
let mut multibuffer = MultiBuffer::new(replica_id, Capability::ReadWrite)
.with_title(resolved_step.title.clone());
- for (buffer, groups) in &resolved_step.suggestions {
+ for (buffer, groups) in &resolved_step.suggestion_groups {
let excerpt_ids = multibuffer.push_excerpts(
buffer.clone(),
groups.iter().map(|suggestion_group| ExcerptRange {
@@ -3545,7 +3523,7 @@ impl ContextEditor {
let button_text = match self.active_workflow_step() {
Some(step) => match step.status(cx) {
WorkflowStepStatus::Resolving => "Resolving Step...",
- WorkflowStepStatus::Error(_) => "Retry Step Resolution",
+ WorkflowStepStatus::Empty | WorkflowStepStatus::Error(_) => "Retry Step Resolution",
WorkflowStepStatus::Idle => "Transform",
WorkflowStepStatus::Pending => "Transforming...",
WorkflowStepStatus::Done => "Accept Transformation",
@@ -3605,28 +3583,11 @@ impl ContextEditor {
fn active_workflow_step_for_cursor(&self, cx: &AppContext) -> Option<ActiveWorkflowStep> {
let newest_cursor = self.editor.read(cx).selections.newest::<usize>(cx).head();
let context = self.context.read(cx);
- let buffer = context.buffer().read(cx);
-
- let workflow_steps = context.workflow_steps();
- workflow_steps
- .binary_search_by(|step| {
- let step_range = step.tagged_range.to_offset(&buffer);
- if newest_cursor < step_range.start {
- Ordering::Greater
- } else if newest_cursor > step_range.end {
- Ordering::Less
- } else {
- Ordering::Equal
- }
- })
- .ok()
- .and_then(|index| {
- let range = workflow_steps[index].tagged_range.clone();
- Some(ActiveWorkflowStep {
- resolved: self.workflow_steps.get(&range)?.resolved_step.is_some(),
- range,
- })
- })
+ let (range, step) = context.workflow_step_containing(newest_cursor, cx)?;
+ Some(ActiveWorkflowStep {
+ resolved: step.read(cx).resolution.is_some(),
+ range,
+ })
}
}
@@ -3654,7 +3615,6 @@ impl Render for ContextEditor {
.capture_action(cx.listener(ContextEditor::confirm_command))
.on_action(cx.listener(ContextEditor::assist))
.on_action(cx.listener(ContextEditor::split))
- .on_action(cx.listener(ContextEditor::debug_workflow_steps))
.size_full()
.children(self.render_notice(cx))
.child(
@@ -3974,12 +3934,13 @@ pub struct ContextEditorToolbarItem {
workspace: WeakView<Workspace>,
active_context_editor: Option<WeakView<ContextEditor>>,
model_summary_editor: View<Editor>,
+ model_selector_menu_handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>,
}
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
- _model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
+ model_selector_menu_handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>,
model_summary_editor: View<Editor>,
) -> Self {
Self {
@@ -3987,6 +3948,7 @@ impl ContextEditorToolbarItem {
workspace: workspace.weak_handle(),
active_context_editor: None,
model_summary_editor,
+ model_selector_menu_handle,
}
}
@@ -4120,49 +4082,52 @@ impl Render for ContextEditorToolbarItem {
let right_side = h_flex()
.gap_2()
- .child(ModelSelector::new(
- self.fs.clone(),
- ButtonLike::new("active-model")
- .style(ButtonStyle::Subtle)
- .child(
- h_flex()
- .w_full()
- .gap_0p5()
- .child(
- div()
- .overflow_x_hidden()
- .flex_grow()
- .whitespace_nowrap()
- .child(match (active_provider, active_model) {
- (Some(provider), Some(model)) => h_flex()
- .gap_1()
- .child(
- Icon::new(provider.icon())
- .color(Color::Muted)
- .size(IconSize::XSmall),
- )
- .child(
- Label::new(model.name().0)
- .size(LabelSize::Small)
- .color(Color::Muted),
- )
- .into_any_element(),
- _ => Label::new("No model selected")
- .size(LabelSize::Small)
- .color(Color::Muted)
- .into_any_element(),
- }),
- )
- .child(
- Icon::new(IconName::ChevronDown)
- .color(Color::Muted)
- .size(IconSize::XSmall),
- ),
- )
- .tooltip(move |cx| {
- Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
- }),
- ))
+ .child(
+ ModelSelector::new(
+ self.fs.clone(),
+ ButtonLike::new("active-model")
+ .style(ButtonStyle::Subtle)
+ .child(
+ h_flex()
+ .w_full()
+ .gap_0p5()
+ .child(
+ div()
+ .overflow_x_hidden()
+ .flex_grow()
+ .whitespace_nowrap()
+ .child(match (active_provider, active_model) {
+ (Some(provider), Some(model)) => h_flex()
+ .gap_1()
+ .child(
+ Icon::new(provider.icon())
+ .color(Color::Muted)
+ .size(IconSize::XSmall),
+ )
+ .child(
+ Label::new(model.name().0)
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ )
+ .into_any_element(),
+ _ => Label::new("No model selected")
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .into_any_element(),
+ }),
+ )
+ .child(
+ Icon::new(IconName::ChevronDown)
+ .color(Color::Muted)
+ .size(IconSize::XSmall),
+ ),
+ )
+ .tooltip(move |cx| {
+ Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
+ }),
+ )
+ .with_handle(self.model_selector_menu_handle.clone()),
+ )
.children(self.render_remaining_tokens(cx))
.child(self.render_inject_context_menu(cx));
@@ -1,6 +1,9 @@
+#[cfg(test)]
+mod context_tests;
+
use crate::{
- prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
- InlineAssistId, InlineAssistant, MessageId, MessageStatus,
+ prompts::PromptBuilder, slash_command::SlashCommandLine, workflow::WorkflowStep, MessageId,
+ MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@@ -9,34 +12,25 @@ use assistant_slash_command::{
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
-use editor::Editor;
use fs::{Fs, RemoveOptions};
-use futures::{
- future::{self, Shared},
- stream::FuturesUnordered,
- FutureExt, StreamExt,
-};
+use futures::{future::Shared, stream::FuturesUnordered, FutureExt, StreamExt};
use gpui::{
- AppContext, Context as _, EventEmitter, Image, Model, ModelContext, RenderImage, Subscription,
- Task, UpdateGlobal, View, WeakView,
+ AppContext, Context as _, EventEmitter, Image, Model, ModelContext, RenderImage, SharedString,
+ Subscription, Task,
};
-use language::{
- AnchorRangeExt, Bias, Buffer, BufferSnapshot, LanguageRegistry, OffsetRangeExt, ParseStatus,
- Point, ToOffset,
-};
+use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
use language_model::{
- LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelTool, Role,
+ LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
+ LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
};
use open_ai::Model as OpenAiModel;
use paths::{context_images_dir, contexts_dir};
use project::Project;
-use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use std::{
- cmp::{self, Ordering},
+ cmp::{max, Ordering},
collections::hash_map,
fmt::Debug,
iter, mem,
@@ -46,10 +40,8 @@ use std::{
time::{Duration, Instant},
};
use telemetry_events::AssistantKind;
-use ui::{SharedString, WindowContext};
use util::{post_inc, ResultExt, TryFutureExt};
use uuid::Uuid;
-use workspace::Workspace;
#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextId(String);
@@ -115,6 +107,8 @@ impl ContextOperation {
message.status.context("invalid status")?,
),
timestamp: id.0,
+ should_cache: false,
+ is_cache_anchor: false,
},
version: language::proto::deserialize_version(&insert.version),
})
@@ -129,6 +123,8 @@ impl ContextOperation {
timestamp: language::proto::deserialize_timestamp(
update.timestamp.context("invalid timestamp")?,
),
+ should_cache: false,
+ is_cache_anchor: false,
},
version: language::proto::deserialize_version(&update.version),
}),
@@ -319,8 +315,10 @@ pub struct MessageAnchor {
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
pub role: Role,
- status: MessageStatus,
+ pub status: MessageStatus,
timestamp: clock::Lamport,
+ should_cache: bool,
+ is_cache_anchor: bool,
}
#[derive(Clone, Debug)]
@@ -346,21 +344,19 @@ pub struct Message {
pub anchor: language::Anchor,
pub role: Role,
pub status: MessageStatus,
+ pub cache: bool,
}
impl Message {
- fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
+ fn to_request_message(&self, buffer: &Buffer) -> Option<LanguageModelRequestMessage> {
let mut content = Vec::new();
let mut range_start = self.offset_range.start;
for (image_offset, message_image) in self.image_offsets.iter() {
if *image_offset != range_start {
- content.push(
- buffer
- .text_for_range(range_start..*image_offset)
- .collect::<String>()
- .into(),
- )
+ if let Some(text) = Self::collect_text_content(buffer, range_start..*image_offset) {
+ content.push(text);
+ }
}
if let Some(image) = message_image.image.clone().now_or_never().flatten() {
@@ -370,17 +366,30 @@ impl Message {
range_start = *image_offset;
}
if range_start != self.offset_range.end {
- content.push(
- buffer
- .text_for_range(range_start..self.offset_range.end)
- .collect::<String>()
- .into(),
- )
+ if let Some(text) =
+ Self::collect_text_content(buffer, range_start..self.offset_range.end)
+ {
+ content.push(text);
+ }
+ }
+
+ if content.is_empty() {
+ return None;
}
- LanguageModelRequestMessage {
+ Some(LanguageModelRequestMessage {
role: self.role,
content,
+ cache: self.cache,
+ })
+ }
+
+ fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<MessageContent> {
+ let text: String = buffer.text_for_range(range.clone()).collect();
+ if text.trim().is_empty() {
+ None
+ } else {
+ Some(MessageContent::Text(text))
}
}
}
@@ -393,12 +402,6 @@ pub struct ImageAnchor {
pub image: Shared<Task<Option<LanguageModelImage>>>,
}
-impl PartialEq for ImageAnchor {
- fn eq(&self, other: &Self) -> bool {
- self.image_id == other.image_id
- }
-}
-
struct PendingCompletion {
id: usize,
assistant_message_id: MessageId,
@@ -408,251 +411,9 @@ struct PendingCompletion {
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
-#[derive(Debug)]
-pub struct WorkflowStep {
- pub tagged_range: Range<language::Anchor>,
- pub status: WorkflowStepStatus,
-}
-
-#[derive(Clone, Debug, Eq, PartialEq)]
-pub struct ResolvedWorkflowStep {
- pub title: String,
- pub suggestions: HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>,
-}
-
-pub enum WorkflowStepStatus {
- Pending(Task<Option<()>>),
- Resolved(ResolvedWorkflowStep),
- Error(Arc<anyhow::Error>),
-}
-
-impl WorkflowStepStatus {
- pub fn into_resolved(&self) -> Option<Result<ResolvedWorkflowStep, Arc<anyhow::Error>>> {
- match self {
- WorkflowStepStatus::Resolved(resolved) => Some(Ok(resolved.clone())),
- WorkflowStepStatus::Error(error) => Some(Err(error.clone())),
- WorkflowStepStatus::Pending(_) => None,
- }
- }
-}
-
-#[derive(Clone, Debug, Eq, PartialEq)]
-pub struct WorkflowSuggestionGroup {
- pub context_range: Range<language::Anchor>,
- pub suggestions: Vec<WorkflowSuggestion>,
-}
-
-#[derive(Clone, Debug, Eq, PartialEq)]
-pub enum WorkflowSuggestion {
- Update {
- range: Range<language::Anchor>,
- description: String,
- },
- CreateFile {
- description: String,
- },
- InsertSiblingBefore {
- position: language::Anchor,
- description: String,
- },
- InsertSiblingAfter {
- position: language::Anchor,
- description: String,
- },
- PrependChild {
- position: language::Anchor,
- description: String,
- },
- AppendChild {
- position: language::Anchor,
- description: String,
- },
- Delete {
- range: Range<language::Anchor>,
- },
-}
-
-impl WorkflowSuggestion {
- pub fn range(&self) -> Range<language::Anchor> {
- match self {
- WorkflowSuggestion::Update { range, .. } => range.clone(),
- WorkflowSuggestion::CreateFile { .. } => language::Anchor::MIN..language::Anchor::MAX,
- WorkflowSuggestion::InsertSiblingBefore { position, .. }
- | WorkflowSuggestion::InsertSiblingAfter { position, .. }
- | WorkflowSuggestion::PrependChild { position, .. }
- | WorkflowSuggestion::AppendChild { position, .. } => *position..*position,
- WorkflowSuggestion::Delete { range } => range.clone(),
- }
- }
-
- pub fn description(&self) -> Option<&str> {
- match self {
- WorkflowSuggestion::Update { description, .. }
- | WorkflowSuggestion::CreateFile { description }
- | WorkflowSuggestion::InsertSiblingBefore { description, .. }
- | WorkflowSuggestion::InsertSiblingAfter { description, .. }
- | WorkflowSuggestion::PrependChild { description, .. }
- | WorkflowSuggestion::AppendChild { description, .. } => Some(description),
- WorkflowSuggestion::Delete { .. } => None,
- }
- }
-
- fn description_mut(&mut self) -> Option<&mut String> {
- match self {
- WorkflowSuggestion::Update { description, .. }
- | WorkflowSuggestion::CreateFile { description }
- | WorkflowSuggestion::InsertSiblingBefore { description, .. }
- | WorkflowSuggestion::InsertSiblingAfter { description, .. }
- | WorkflowSuggestion::PrependChild { description, .. }
- | WorkflowSuggestion::AppendChild { description, .. } => Some(description),
- WorkflowSuggestion::Delete { .. } => None,
- }
- }
-
- fn try_merge(&mut self, other: &Self, buffer: &BufferSnapshot) -> bool {
- let range = self.range();
- let other_range = other.range();
-
- // Don't merge if we don't contain the other suggestion.
- if range.start.cmp(&other_range.start, buffer).is_gt()
- || range.end.cmp(&other_range.end, buffer).is_lt()
- {
- return false;
- }
-
- if let Some(description) = self.description_mut() {
- if let Some(other_description) = other.description() {
- description.push('\n');
- description.push_str(other_description);
- }
- }
- true
- }
-
- pub fn show(
- &self,
- editor: &View<Editor>,
- excerpt_id: editor::ExcerptId,
- workspace: &WeakView<Workspace>,
- assistant_panel: &View<AssistantPanel>,
- cx: &mut WindowContext,
- ) -> Option<InlineAssistId> {
- let mut initial_transaction_id = None;
- let initial_prompt;
- let suggestion_range;
- let buffer = editor.read(cx).buffer().clone();
- let snapshot = buffer.read(cx).snapshot(cx);
-
- match self {
- WorkflowSuggestion::Update { range, description } => {
- initial_prompt = description.clone();
- suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
- ..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
- }
- WorkflowSuggestion::CreateFile { description } => {
- initial_prompt = description.clone();
- suggestion_range = editor::Anchor::min()..editor::Anchor::min();
- }
- WorkflowSuggestion::InsertSiblingBefore {
- position,
- description,
- } => {
- let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
- initial_prompt = description.clone();
- suggestion_range = buffer.update(cx, |buffer, cx| {
- buffer.start_transaction(cx);
- let line_start = buffer.insert_empty_line(position, true, true, cx);
- initial_transaction_id = buffer.end_transaction(cx);
- buffer.refresh_preview(cx);
-
- let line_start = buffer.read(cx).anchor_before(line_start);
- line_start..line_start
- });
- }
- WorkflowSuggestion::InsertSiblingAfter {
- position,
- description,
- } => {
- let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
- initial_prompt = description.clone();
- suggestion_range = buffer.update(cx, |buffer, cx| {
- buffer.start_transaction(cx);
- let line_start = buffer.insert_empty_line(position, true, true, cx);
- initial_transaction_id = buffer.end_transaction(cx);
- buffer.refresh_preview(cx);
-
- let line_start = buffer.read(cx).anchor_before(line_start);
- line_start..line_start
- });
- }
- WorkflowSuggestion::PrependChild {
- position,
- description,
- } => {
- let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
- initial_prompt = description.clone();
- suggestion_range = buffer.update(cx, |buffer, cx| {
- buffer.start_transaction(cx);
- let line_start = buffer.insert_empty_line(position, false, true, cx);
- initial_transaction_id = buffer.end_transaction(cx);
- buffer.refresh_preview(cx);
-
- let line_start = buffer.read(cx).anchor_before(line_start);
- line_start..line_start
- });
- }
- WorkflowSuggestion::AppendChild {
- position,
- description,
- } => {
- let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
- initial_prompt = description.clone();
- suggestion_range = buffer.update(cx, |buffer, cx| {
- buffer.start_transaction(cx);
- let line_start = buffer.insert_empty_line(position, true, false, cx);
- initial_transaction_id = buffer.end_transaction(cx);
- buffer.refresh_preview(cx);
-
- let line_start = buffer.read(cx).anchor_before(line_start);
- line_start..line_start
- });
- }
- WorkflowSuggestion::Delete { range } => {
- initial_prompt = "Delete".to_string();
- suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
- ..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
- }
- }
-
- InlineAssistant::update_global(cx, |inline_assistant, cx| {
- Some(inline_assistant.suggest_assist(
- editor,
- suggestion_range,
- initial_prompt,
- initial_transaction_id,
- Some(workspace.clone()),
- Some(assistant_panel),
- cx,
- ))
- })
- }
-}
-
-impl Debug for WorkflowStepStatus {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- WorkflowStepStatus::Pending(_) => write!(f, "WorkflowStepStatus::Pending"),
- WorkflowStepStatus::Resolved(ResolvedWorkflowStep { title, suggestions }) => f
- .debug_struct("WorkflowStepStatus::Resolved")
- .field("title", title)
- .field("suggestions", suggestions)
- .finish(),
- WorkflowStepStatus::Error(error) => f
- .debug_tuple("WorkflowStepStatus::Error")
- .field(error)
- .finish(),
- }
- }
+struct WorkflowStepEntry {
+ range: Range<language::Anchor>,
+ step: Model<WorkflowStep>,
}
pub struct Context {
@@ -677,11 +438,12 @@ pub struct Context {
token_count: Option<usize>,
pending_token_count: Task<Option<()>>,
pending_save: Task<Result<()>>,
+ pending_cache_warming_task: Task<Option<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
telemetry: Option<Arc<Telemetry>>,
language_registry: Arc<LanguageRegistry>,
- workflow_steps: Vec<WorkflowStep>,
+ workflow_steps: Vec<WorkflowStepEntry>,
edits_since_last_workflow_step_prune: language::Subscription,
project: Option<Model<Project>>,
prompt_builder: Arc<PromptBuilder>,
@@ -754,6 +516,7 @@ impl Context {
pending_completions: Default::default(),
token_count: None,
pending_token_count: Task::ready(None),
+ pending_cache_warming_task: Task::ready(None),
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
@@ -780,6 +543,8 @@ impl Context {
role: Role::User,
status: MessageStatus::Done,
timestamp: first_message_id.0,
+ should_cache: false,
+ is_cache_anchor: false,
},
);
this.message_anchors.push(message);
@@ -789,7 +554,7 @@ impl Context {
this
}
- fn serialize(&self, cx: &AppContext) -> SavedContext {
+ pub(crate) fn serialize(&self, cx: &AppContext) -> SavedContext {
let buffer = self.buffer.read(cx);
SavedContext {
id: Some(self.id.clone()),
@@ -1082,6 +847,18 @@ impl Context {
&self.buffer
}
+ pub fn language_registry(&self) -> Arc<LanguageRegistry> {
+ self.language_registry.clone()
+ }
+
+ pub fn project(&self) -> Option<Model<Project>> {
+ self.project.clone()
+ }
+
+ pub fn prompt_builder(&self) -> Arc<PromptBuilder> {
+ self.prompt_builder.clone()
+ }
+
pub fn path(&self) -> Option<&Path> {
self.path.as_deref()
}
@@ -1090,14 +867,46 @@ impl Context {
self.summary.as_ref()
}
- pub fn workflow_steps(&self) -> &[WorkflowStep] {
- &self.workflow_steps
+ pub fn workflow_step_containing(
+ &self,
+ offset: usize,
+ cx: &AppContext,
+ ) -> Option<(Range<language::Anchor>, Model<WorkflowStep>)> {
+ let buffer = self.buffer.read(cx);
+ let index = self
+ .workflow_steps
+ .binary_search_by(|step| {
+ let step_range = step.range.to_offset(&buffer);
+ if offset < step_range.start {
+ Ordering::Greater
+ } else if offset > step_range.end {
+ Ordering::Less
+ } else {
+ Ordering::Equal
+ }
+ })
+ .ok()?;
+ let step = &self.workflow_steps[index];
+ Some((step.range.clone(), step.step.clone()))
+ }
+
+ pub fn workflow_step_for_range(
+ &self,
+ range: Range<language::Anchor>,
+ cx: &AppContext,
+ ) -> Option<Model<WorkflowStep>> {
+ let buffer = self.buffer.read(cx);
+ let index = self.workflow_step_index_for_range(&range, buffer).ok()?;
+ Some(self.workflow_steps[index].step.clone())
}
- pub fn workflow_step_for_range(&self, range: Range<language::Anchor>) -> Option<&WorkflowStep> {
+ pub fn workflow_step_index_for_range(
+ &self,
+ tagged_range: &Range<text::Anchor>,
+ buffer: &text::BufferSnapshot,
+ ) -> Result<usize, usize> {
self.workflow_steps
- .iter()
- .find(|step| step.tagged_range == range)
+ .binary_search_by(|probe| probe.range.cmp(&tagged_range, buffer))
}
pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
@@ -1160,6 +969,7 @@ impl Context {
let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
this.update(&mut cx, |this, cx| {
this.token_count = Some(token_count);
+ this.start_cache_warming(&model, cx);
cx.notify()
})
}
@@ -1167,6 +977,121 @@ impl Context {
});
}
+ pub fn mark_longest_messages_for_cache(
+ &mut self,
+ cache_configuration: &Option<LanguageModelCacheConfiguration>,
+ speculative: bool,
+ cx: &mut ModelContext<Self>,
+ ) -> bool {
+ let cache_configuration =
+ cache_configuration
+ .as_ref()
+ .unwrap_or(&LanguageModelCacheConfiguration {
+ max_cache_anchors: 0,
+ should_speculate: false,
+ min_total_token: 0,
+ });
+
+ let messages: Vec<Message> = self
+ .messages_from_anchors(
+ self.message_anchors.iter().take(if speculative {
+ self.message_anchors.len().saturating_sub(1)
+ } else {
+ self.message_anchors.len()
+ }),
+ cx,
+ )
+ .filter(|message| message.offset_range.len() >= 5_000)
+ .collect();
+
+ let mut sorted_messages = messages.clone();
+ sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
+ if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
+ // Some models support caching, but don't support anchors. In that case we want to
+ // mark the largest message as needing to be cached, but we will not mark it as an
+ // anchor.
+ sorted_messages.truncate(1);
+ } else {
+ // Save 1 anchor for the inline assistant.
+ sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
+ }
+
+ let longest_message_ids: HashSet<MessageId> = sorted_messages
+ .into_iter()
+ .map(|message| message.id)
+ .collect();
+
+ let cache_deltas: HashSet<MessageId> = self
+ .messages_metadata
+ .iter()
+ .filter_map(|(id, metadata)| {
+ let should_cache = longest_message_ids.contains(id);
+ let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
+ if metadata.should_cache != should_cache
+ || metadata.is_cache_anchor != should_be_anchor
+ {
+ Some(*id)
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ let mut newly_cached_item = false;
+ for id in cache_deltas {
+ newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
+ self.update_metadata(id, cx, |metadata| {
+ metadata.should_cache = longest_message_ids.contains(&id);
+ metadata.is_cache_anchor =
+ metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
+ });
+ }
+ newly_cached_item
+ }
+
+ fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
+ let cache_configuration = model.cache_configuration();
+ if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
+ return;
+ }
+ if let Some(cache_configuration) = cache_configuration {
+ if !cache_configuration.should_speculate {
+ return;
+ }
+ }
+
+ let request = {
+ let mut req = self.to_completion_request(cx);
+ // Skip the last message because it's likely to change and
+ // therefore would be a waste to cache.
+ req.messages.pop();
+ req.messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec!["Respond only with OK, nothing else.".into()],
+ cache: false,
+ });
+ req
+ };
+
+ let model = Arc::clone(model);
+ self.pending_cache_warming_task = cx.spawn(|_, cx| {
+ async move {
+ match model.stream_completion(request, &cx).await {
+ Ok(mut stream) => {
+ stream.next().await;
+ log::info!("Cache warming completed successfully");
+ }
+ Err(e) => {
+ log::warn!("Cache warming failed: {}", e);
+ }
+ };
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
+ }
+
pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
let buffer = self.buffer.read(cx);
let mut row_ranges = self
@@ -1260,7 +1185,7 @@ impl Context {
removed.extend(
self.workflow_steps
.drain(intersecting_range)
- .map(|step| step.tagged_range),
+ .map(|step| step.range),
);
}
@@ -1279,7 +1204,7 @@ impl Context {
let buffer = self.buffer.read(cx);
let start_ix = match self.workflow_steps.binary_search_by(|probe| {
probe
- .tagged_range
+ .range
.end
.to_offset(buffer)
.cmp(&range.start)
@@ -1293,7 +1218,7 @@ impl Context {
};
let end_ix = match self.workflow_steps.binary_search_by(|probe| {
probe
- .tagged_range
+ .range
.start
.to_offset(buffer)
.cmp(&range.end)
@@ -1308,12 +1233,8 @@ impl Context {
start_ix..end_ix
}
- fn parse_workflow_steps_in_range(
- &mut self,
- range: Range<usize>,
- project: Model<Project>,
- cx: &mut ModelContext<Self>,
- ) {
+ fn parse_workflow_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
+ let weak_self = cx.weak_model();
let mut new_edit_steps = Vec::new();
let mut edits = Vec::new();
@@ -1348,16 +1269,17 @@ impl Context {
..buffer.anchor_before(step_end_tag_start_ix);
// Check if a step with the same range already exists
- let existing_step_index = self
- .workflow_steps
- .binary_search_by(|probe| probe.tagged_range.cmp(&tagged_range, &buffer));
+ let existing_step_index =
+ self.workflow_step_index_for_range(&tagged_range, &buffer);
if let Err(ix) = existing_step_index {
new_edit_steps.push((
ix,
- WorkflowStep {
- tagged_range,
- status: WorkflowStepStatus::Pending(Task::ready(None)),
+ WorkflowStepEntry {
+ step: cx.new_model(|_| {
+ WorkflowStep::new(tagged_range.clone(), weak_self.clone())
+ }),
+ range: tagged_range,
},
));
}
@@ -1371,10 +1293,10 @@ impl Context {
let mut updated = Vec::new();
for (index, step) in new_edit_steps.into_iter().rev() {
- let step_range = step.tagged_range.clone();
+ let step_range = step.range.clone();
updated.push(step_range.clone());
self.workflow_steps.insert(index, step);
- self.resolve_workflow_step(step_range, project.clone(), cx);
+ self.resolve_workflow_step(step_range, cx);
}
// Delete <step> tags, making sure we don't accidentally invalidate
@@ -1387,161 +1309,30 @@ impl Context {
pub fn resolve_workflow_step(
&mut self,
tagged_range: Range<language::Anchor>,
- project: Model<Project>,
cx: &mut ModelContext<Self>,
) {
let Ok(step_index) = self
.workflow_steps
- .binary_search_by(|step| step.tagged_range.cmp(&tagged_range, self.buffer.read(cx)))
+ .binary_search_by(|step| step.range.cmp(&tagged_range, self.buffer.read(cx)))
else {
return;
};
- let mut request = self.to_completion_request(cx);
- let Some(edit_step) = self.workflow_steps.get_mut(step_index) else {
- return;
- };
-
- if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
- let step_text = self
- .buffer
- .read(cx)
- .text_for_range(tagged_range.clone())
- .collect::<String>();
-
- let tagged_range = tagged_range.clone();
- edit_step.status = WorkflowStepStatus::Pending(cx.spawn(|this, mut cx| {
- async move {
- let result = async {
- let mut prompt = this.update(&mut cx, |this, _| {
- this.prompt_builder.generate_step_resolution_prompt()
- })??;
- prompt.push_str(&step_text);
-
- request.messages.push(LanguageModelRequestMessage {
- role: Role::User,
- content: vec![prompt.into()],
- });
-
- // Invoke the model to get its edit suggestions for this workflow step.
- let resolution = model
- .use_tool::<tool::WorkflowStepResolution>(request, &cx)
- .await?;
-
- // Translate the parsed suggestions to our internal types, which anchor the suggestions to locations in the code.
- let suggestion_tasks: Vec<_> = resolution
- .suggestions
- .iter()
- .map(|suggestion| suggestion.resolve(project.clone(), cx.clone()))
- .collect();
-
- // Expand the context ranges of each suggestion and group suggestions with overlapping context ranges.
- let suggestions = future::join_all(suggestion_tasks)
- .await
- .into_iter()
- .filter_map(|task| task.log_err())
- .collect::<Vec<_>>();
-
- let mut suggestions_by_buffer = HashMap::default();
- for (buffer, suggestion) in suggestions {
- suggestions_by_buffer
- .entry(buffer)
- .or_insert_with(Vec::new)
- .push(suggestion);
- }
-
- let mut suggestion_groups_by_buffer = HashMap::default();
- for (buffer, mut suggestions) in suggestions_by_buffer {
- let mut suggestion_groups = Vec::<WorkflowSuggestionGroup>::new();
- let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
- // Sort suggestions by their range so that earlier, larger ranges come first
- suggestions.sort_by(|a, b| a.range().cmp(&b.range(), &snapshot));
-
- // Merge overlapping suggestions
- suggestions.dedup_by(|a, b| b.try_merge(&a, &snapshot));
-
- // Create context ranges for each suggestion
- for suggestion in suggestions {
- let context_range = {
- let suggestion_point_range =
- suggestion.range().to_point(&snapshot);
- let start_row =
- suggestion_point_range.start.row.saturating_sub(5);
- let end_row = cmp::min(
- suggestion_point_range.end.row + 5,
- snapshot.max_point().row,
- );
- let start = snapshot.anchor_before(Point::new(start_row, 0));
- let end = snapshot.anchor_after(Point::new(
- end_row,
- snapshot.line_len(end_row),
- ));
- start..end
- };
-
- if let Some(last_group) = suggestion_groups.last_mut() {
- if last_group
- .context_range
- .end
- .cmp(&context_range.start, &snapshot)
- .is_ge()
- {
- // Merge with the previous group if context ranges overlap
- last_group.context_range.end = context_range.end;
- last_group.suggestions.push(suggestion);
- } else {
- // Create a new group
- suggestion_groups.push(WorkflowSuggestionGroup {
- context_range,
- suggestions: vec![suggestion],
- });
- }
- } else {
- // Create the first group
- suggestion_groups.push(WorkflowSuggestionGroup {
- context_range,
- suggestions: vec![suggestion],
- });
- }
- }
-
- suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
- }
-
- Ok((resolution.step_title, suggestion_groups_by_buffer))
- };
+ cx.emit(ContextEvent::WorkflowStepUpdated(tagged_range.clone()));
+ cx.notify();
- let result = result.await;
- this.update(&mut cx, |this, cx| {
- let step_index = this
- .workflow_steps
- .binary_search_by(|step| {
- step.tagged_range.cmp(&tagged_range, this.buffer.read(cx))
- })
- .map_err(|_| anyhow!("edit step not found"))?;
- if let Some(edit_step) = this.workflow_steps.get_mut(step_index) {
- edit_step.status = match result {
- Ok((title, suggestions)) => {
- WorkflowStepStatus::Resolved(ResolvedWorkflowStep {
- title,
- suggestions,
- })
- }
- Err(error) => WorkflowStepStatus::Error(Arc::new(error)),
- };
- cx.emit(ContextEvent::WorkflowStepUpdated(tagged_range));
- cx.notify();
- }
- anyhow::Ok(())
- })?
- }
- .log_err()
- }));
- } else {
- edit_step.status = WorkflowStepStatus::Error(Arc::new(anyhow!("no active model")));
- }
+ let resolution = self.workflow_steps[step_index].step.clone();
+ cx.defer(move |cx| {
+ resolution.update(cx, |resolution, cx| resolution.resolve(cx));
+ });
+ }
- cx.emit(ContextEvent::WorkflowStepUpdated(tagged_range));
+ pub fn workflow_step_updated(
+ &mut self,
+ range: Range<language::Anchor>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ cx.emit(ContextEvent::WorkflowStepUpdated(range));
cx.notify();
}
@@ -1698,20 +1489,26 @@ impl Context {
self.count_remaining_tokens(cx);
}
- pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
- let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
- let model = LanguageModelRegistry::read_global(cx).active_model()?;
- let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
+ fn get_last_valid_message_id(&self, cx: &ModelContext<Self>) -> Option<MessageId> {
+ self.message_anchors.iter().rev().find_map(|message| {
message
.start
.is_valid(self.buffer.read(cx))
.then_some(message.id)
- })?;
+ })
+ }
+
+ pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+ let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+ let model = LanguageModelRegistry::read_global(cx).active_model()?;
+ let last_message_id = self.get_last_valid_message_id(cx)?;
if !provider.is_authenticated(cx) {
log::info!("completion provider has no credentials");
return None;
}
+ // Compute which messages to cache, including the last one.
+ self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
let request = self.to_completion_request(cx);
let assistant_message = self
@@ -1762,11 +1559,10 @@ impl Context {
);
message_start_offset..message_new_end_offset
});
- if let Some(project) = this.project.clone() {
- // Use `inclusive = false` as edits might occur at the end of a parsed step.
- this.prune_invalid_workflow_steps(false, cx);
- this.parse_workflow_steps_in_range(message_range, project, cx);
- }
+
+ // Use `inclusive = false` as edits might occur at the end of a parsed step.
+ this.prune_invalid_workflow_steps(false, cx);
+ this.parse_workflow_steps_in_range(message_range, cx);
cx.emit(ContextEvent::StreamedCompletion);
Some(())
@@ -1832,7 +1628,7 @@ impl Context {
let request_messages = self
.messages(cx)
.filter(|message| message.status == MessageStatus::Done)
- .map(|message| message.to_request_message(&buffer))
+ .filter_map(|message| message.to_request_message(&buffer))
.collect();
LanguageModelRequest {
@@ -1886,7 +1682,7 @@ impl Context {
}
}
- fn insert_message_after(
+ pub fn insert_message_after(
&mut self,
message_id: MessageId,
role: Role,
@@ -1927,6 +1723,8 @@ impl Context {
role,
status,
timestamp: anchor.id.0,
+ should_cache: false,
+ is_cache_anchor: false,
};
self.insert_message(anchor.clone(), metadata.clone(), cx);
self.push_op(
@@ -2043,6 +1841,8 @@ impl Context {
role,
status: MessageStatus::Done,
timestamp: suffix.id.0,
+ should_cache: false,
+ is_cache_anchor: false,
};
self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
self.push_op(
@@ -0,0 +1,1055 @@
+use crate::{
+ assistant_panel, prompt_library, slash_command::file_command, workflow::tool, Context,
+ ContextEvent, ContextId, ContextOperation, MessageId, MessageStatus, PromptBuilder,
+};
+use anyhow::Result;
+use assistant_slash_command::{
+ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
+ SlashCommandRegistry,
+};
+use collections::HashSet;
+use fs::{FakeFs, Fs as _};
+use gpui::{AppContext, Model, SharedString, Task, TestAppContext, WeakView};
+use indoc::indoc;
+use language::{Buffer, LanguageRegistry, LspAdapterDelegate};
+use language_model::{LanguageModelRegistry, Role};
+use parking_lot::Mutex;
+use project::Project;
+use rand::prelude::*;
+use rope::Point;
+use serde_json::json;
+use settings::SettingsStore;
+use std::{
+ cell::RefCell,
+ env,
+ ops::Range,
+ path::Path,
+ rc::Rc,
+ sync::{atomic::AtomicBool, Arc},
+};
+use text::{network::Network, OffsetRangeExt as _, ReplicaId, ToPoint as _};
+use ui::{Context as _, WindowContext};
+use unindent::Unindent;
+use util::{test::marked_text_ranges, RandomCharIter};
+use workspace::Workspace;
+
+#[gpui::test]
+fn test_inserting_and_removing_messages(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ LanguageModelRegistry::test(cx);
+ cx.set_global(settings_store);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ let message_2 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..1),
+ (message_2.id, Role::Assistant, 1..1)
+ ]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..3)
+ ]
+ );
+
+ let message_3 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_3.id, Role::User, 4..4)
+ ]
+ );
+
+ let message_4 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..5),
+ (message_3.id, Role::User, 5..5),
+ ]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..6),
+ (message_3.id, Role::User, 6..7),
+ ]
+ );
+
+ // Deleting across message boundaries merges the messages.
+ buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_3.id, Role::User, 3..4),
+ ]
+ );
+
+ // Undoing the deletion should also undo the merge.
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..2),
+ (message_2.id, Role::Assistant, 2..4),
+ (message_4.id, Role::User, 4..6),
+ (message_3.id, Role::User, 6..7),
+ ]
+ );
+
+ // Redoing the deletion should also redo the merge.
+ buffer.update(cx, |buffer, cx| buffer.redo(cx));
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_3.id, Role::User, 3..4),
+ ]
+ );
+
+ // Ensure we can still insert after a merged message.
+ let message_5 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..3),
+ (message_5.id, Role::System, 3..4),
+ (message_3.id, Role::User, 4..5)
+ ]
+ );
+}
+
+#[gpui::test]
+fn test_message_splitting(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ LanguageModelRegistry::test(cx);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
+ });
+
+ let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
+ let message_2 = message_2.unwrap();
+
+ // We recycle newlines in the middle of a split message
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..16),
+ ]
+ );
+
+ let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
+ let message_3 = message_3.unwrap();
+
+ // We don't recycle newlines at the end of a split message
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..17),
+ ]
+ );
+
+ let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
+ let message_4 = message_4.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..17),
+ ]
+ );
+
+ let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
+ let message_5 = message_5.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..10),
+ (message_5.id, Role::User, 10..18),
+ ]
+ );
+
+ let (message_6, message_7) =
+ context.update(cx, |context, cx| context.split_message(14..16, cx));
+ let message_6 = message_6.unwrap();
+ let message_7 = message_7.unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_3.id, Role::User, 4..5),
+ (message_2.id, Role::User, 5..9),
+ (message_4.id, Role::User, 9..10),
+ (message_5.id, Role::User, 10..14),
+ (message_6.id, Role::User, 14..17),
+ (message_7.id, Role::User, 17..19),
+ ]
+ );
+}
+
+#[gpui::test]
+fn test_messages_for_offsets(cx: &mut AppContext) {
+ let settings_store = SettingsStore::test(cx);
+ LanguageModelRegistry::test(cx);
+ cx.set_global(settings_store);
+ assistant_panel::init(cx);
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
+ let buffer = context.read(cx).buffer.clone();
+
+ let message_1 = context.read(cx).message_anchors[0].clone();
+ assert_eq!(
+ messages(&context, cx),
+ vec![(message_1.id, Role::User, 0..0)]
+ );
+
+ buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
+ let message_2 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
+
+ let message_3 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
+
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..8),
+ (message_3.id, Role::User, 8..11)
+ ]
+ );
+
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 4, 9], cx),
+ [message_1.id, message_2.id, message_3.id]
+ );
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 1, 11], cx),
+ [message_1.id, message_3.id]
+ );
+
+ let message_4 = context
+ .update(cx, |context, cx| {
+ context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
+ })
+ .unwrap();
+ assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
+ assert_eq!(
+ messages(&context, cx),
+ vec![
+ (message_1.id, Role::User, 0..4),
+ (message_2.id, Role::User, 4..8),
+ (message_3.id, Role::User, 8..12),
+ (message_4.id, Role::User, 12..12)
+ ]
+ );
+ assert_eq!(
+ message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
+ [message_1.id, message_2.id, message_3.id, message_4.id]
+ );
+
+ fn message_ids_for_offsets(
+ context: &Model<Context>,
+ offsets: &[usize],
+ cx: &AppContext,
+ ) -> Vec<MessageId> {
+ context
+ .read(cx)
+ .messages_for_offsets(offsets.iter().copied(), cx)
+ .into_iter()
+ .map(|message| message.id)
+ .collect()
+ }
+}
+
+#[gpui::test]
+async fn test_slash_commands(cx: &mut TestAppContext) {
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(LanguageModelRegistry::test);
+ cx.update(Project::init_settings);
+ cx.update(assistant_panel::init);
+ let fs = FakeFs::new(cx.background_executor.clone());
+
+ fs.insert_tree(
+ "/test",
+ json!({
+ "src": {
+ "lib.rs": "fn one() -> usize { 1 }",
+ "main.rs": "
+ use crate::one;
+ fn main() { one(); }
+ ".unindent(),
+ }
+ }),
+ )
+ .await;
+
+ let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
+ slash_command_registry.register_command(file_command::FileSlashCommand, false);
+
+ let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
+
+ let output_ranges = Rc::new(RefCell::new(HashSet::default()));
+ context.update(cx, |_, cx| {
+ cx.subscribe(&context, {
+ let ranges = output_ranges.clone();
+ move |_, _, event, _| match event {
+ ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
+ for range in removed {
+ ranges.borrow_mut().remove(range);
+ }
+ for command in updated {
+ ranges.borrow_mut().insert(command.source_range.clone());
+ }
+ }
+ _ => {}
+ }
+ })
+ .detach();
+ });
+
+ let buffer = context.read_with(cx, |context, _| context.buffer.clone());
+
+ // Insert a slash command
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ «/file src/lib.rs»
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ // Edit the argument of the slash command.
+ buffer.update(cx, |buffer, cx| {
+ let edit_offset = buffer.text().find("lib.rs").unwrap();
+ buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ «/file src/main.rs»
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ // Edit the name of the slash command, using one that doesn't exist.
+ buffer.update(cx, |buffer, cx| {
+ let edit_offset = buffer.text().find("/file").unwrap();
+ buffer.edit(
+ [(edit_offset..edit_offset + "/file".len(), "/unknown")],
+ None,
+ cx,
+ );
+ });
+ assert_text_and_output_ranges(
+ &buffer,
+ &output_ranges.borrow(),
+ "
+ /unknown src/main.rs
+ "
+ .unindent()
+ .trim_end(),
+ cx,
+ );
+
+ #[track_caller]
+ fn assert_text_and_output_ranges(
+ buffer: &Model<Buffer>,
+ ranges: &HashSet<Range<language::Anchor>>,
+ expected_marked_text: &str,
+ cx: &mut TestAppContext,
+ ) {
+ let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
+ let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
+ let mut ranges = ranges
+ .iter()
+ .map(|range| range.to_offset(buffer))
+ .collect::<Vec<_>>();
+ ranges.sort_by_key(|a| a.start);
+ (buffer.text(), ranges)
+ });
+
+ assert_eq!(actual_text, expected_text);
+ assert_eq!(actual_ranges, expected_ranges);
+ }
+}
+
+#[gpui::test]
+async fn test_edit_step_parsing(cx: &mut TestAppContext) {
+ cx.update(prompt_library::init);
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(Project::init_settings);
+ let fs = FakeFs::new(cx.executor());
+ fs.as_fake()
+ .insert_tree(
+ "/root",
+ json!({
+ "hello.rs": r#"
+ fn hello() {
+ println!("Hello, World!");
+ }
+ "#.unindent()
+ }),
+ )
+ .await;
+ let project = Project::test(fs, [Path::new("/root")], cx).await;
+ cx.update(LanguageModelRegistry::test);
+
+ let model = cx.read(|cx| {
+ LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .unwrap()
+ });
+ cx.update(assistant_panel::init);
+ let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+
+ // Create a new context
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context = cx.new_model(|cx| {
+ Context::local(
+ registry.clone(),
+ Some(project),
+ None,
+ prompt_builder.clone(),
+ cx,
+ )
+ });
+ let buffer = context.read_with(cx, |context, _| context.buffer.clone());
+
+ // Simulate user input
+ let user_message = indoc! {r#"
+ Please add unnecessary complexity to this code:
+
+ ```hello.rs
+ fn main() {
+ println!("Hello, World!");
+ }
+ ```
+ "#};
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, user_message)], None, cx);
+ });
+
+ // Simulate LLM response with edit steps
+ let llm_response = indoc! {r#"
+ Sure, I can help you with that. Here's a step-by-step process:
+
+ <step>
+ First, let's extract the greeting into a separate function:
+
+ ```rust
+ fn greet() {
+ println!("Hello, World!");
+ }
+
+ fn main() {
+ greet();
+ }
+ ```
+ </step>
+
+ <step>
+ Now, let's make the greeting customizable:
+
+ ```rust
+ fn greet(name: &str) {
+ println!("Hello, {}!", name);
+ }
+
+ fn main() {
+ greet("World");
+ }
+ ```
+ </step>
+
+ These changes make the code more modular and flexible.
+ "#};
+
+ // Simulate the assist method to trigger the LLM response
+ context.update(cx, |context, cx| context.assist(cx));
+ cx.run_until_parked();
+
+ // Retrieve the assistant response message's start from the context
+ let response_start_row = context.read_with(cx, |context, cx| {
+ let buffer = context.buffer.read(cx);
+ context.message_anchors[1].start.to_point(buffer).row
+ });
+
+ // Simulate the LLM completion
+ model
+ .as_fake()
+ .stream_last_completion_response(llm_response.to_string());
+ model.as_fake().end_last_completion_stream();
+
+ // Wait for the completion to be processed
+ cx.run_until_parked();
+
+ // Verify that the edit steps were parsed correctly
+ context.read_with(cx, |context, cx| {
+ assert_eq!(
+ workflow_steps(context, cx),
+ vec![
+ (
+ Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
+ WorkflowStepTestStatus::Pending
+ ),
+ (
+ Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
+ WorkflowStepTestStatus::Pending
+ ),
+ ]
+ );
+ });
+
+ model
+ .as_fake()
+ .respond_to_last_tool_use(tool::WorkflowStepResolutionTool {
+ step_title: "Title".into(),
+ suggestions: vec![tool::WorkflowSuggestionTool {
+ path: "/root/hello.rs".into(),
+ // Simulate a symbol name that's slightly different than our outline query
+ kind: tool::WorkflowSuggestionToolKind::Update {
+ symbol: "fn main()".into(),
+ description: "Extract a greeting function".into(),
+ },
+ }],
+ });
+
+ // Wait for tool use to be processed.
+ cx.run_until_parked();
+
+ // Verify that the first edit step is not pending anymore.
+ context.read_with(cx, |context, cx| {
+ assert_eq!(
+ workflow_steps(context, cx),
+ vec![
+ (
+ Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 12, 3),
+ WorkflowStepTestStatus::Resolved
+ ),
+ (
+ Point::new(response_start_row + 14, 0)..Point::new(response_start_row + 24, 3),
+ WorkflowStepTestStatus::Pending
+ ),
+ ]
+ );
+ });
+
+ #[derive(Copy, Clone, Debug, Eq, PartialEq)]
+ enum WorkflowStepTestStatus {
+ Pending,
+ Resolved,
+ Error,
+ }
+
+ fn workflow_steps(
+ context: &Context,
+ cx: &AppContext,
+ ) -> Vec<(Range<Point>, WorkflowStepTestStatus)> {
+ context
+ .workflow_steps
+ .iter()
+ .map(|step| {
+ let buffer = context.buffer.read(cx);
+ let status = match &step.step.read(cx).resolution {
+ None => WorkflowStepTestStatus::Pending,
+ Some(Ok(_)) => WorkflowStepTestStatus::Resolved,
+ Some(Err(_)) => WorkflowStepTestStatus::Error,
+ };
+ (step.range.to_point(buffer), status)
+ })
+ .collect()
+ }
+}
+
+#[gpui::test]
+async fn test_serialization(cx: &mut TestAppContext) {
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(LanguageModelRegistry::test);
+ cx.update(assistant_panel::init);
+ let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let context =
+ cx.new_model(|cx| Context::local(registry.clone(), None, None, prompt_builder.clone(), cx));
+ let buffer = context.read_with(cx, |context, _| context.buffer.clone());
+ let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
+ let message_1 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ let message_2 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
+ buffer.finalize_last_transaction();
+ });
+ let _message_3 = context.update(cx, |context, cx| {
+ context
+ .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
+ assert_eq!(
+ cx.read(|cx| messages(&context, cx)),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+
+ let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
+ let deserialized_context = cx.new_model(|cx| {
+ Context::deserialize(
+ serialized_context,
+ Default::default(),
+ registry.clone(),
+ prompt_builder.clone(),
+ None,
+ None,
+ cx,
+ )
+ });
+ let deserialized_buffer =
+ deserialized_context.read_with(cx, |context, _| context.buffer.clone());
+ assert_eq!(
+ deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
+ "a\nb\nc\n"
+ );
+ assert_eq!(
+ cx.read(|cx| messages(&deserialized_context, cx)),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+}
+
+#[gpui::test(iterations = 100)]
+async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
+ let min_peers = env::var("MIN_PEERS")
+ .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
+ .unwrap_or(2);
+ let max_peers = env::var("MAX_PEERS")
+ .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
+ .unwrap_or(5);
+ let operations = env::var("OPERATIONS")
+ .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
+ .unwrap_or(50);
+
+ let settings_store = cx.update(SettingsStore::test);
+ cx.set_global(settings_store);
+ cx.update(LanguageModelRegistry::test);
+
+ cx.update(assistant_panel::init);
+ let slash_commands = cx.update(SlashCommandRegistry::default_global);
+ slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
+ slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
+ slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
+
+ let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
+ let network = Arc::new(Mutex::new(Network::new(rng.clone())));
+ let mut contexts = Vec::new();
+
+ let num_peers = rng.gen_range(min_peers..=max_peers);
+ let context_id = ContextId::new();
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ for i in 0..num_peers {
+ let context = cx.new_model(|cx| {
+ Context::new(
+ context_id.clone(),
+ i as ReplicaId,
+ language::Capability::ReadWrite,
+ registry.clone(),
+ prompt_builder.clone(),
+ None,
+ None,
+ cx,
+ )
+ });
+
+ cx.update(|cx| {
+ cx.subscribe(&context, {
+ let network = network.clone();
+ move |_, event, _| {
+ if let ContextEvent::Operation(op) = event {
+ network
+ .lock()
+ .broadcast(i as ReplicaId, vec![op.to_proto()]);
+ }
+ }
+ })
+ .detach();
+ });
+
+ contexts.push(context);
+ network.lock().add_peer(i as ReplicaId);
+ }
+
+ let mut mutation_count = operations;
+
+ while mutation_count > 0
+ || !network.lock().is_idle()
+ || network.lock().contains_disconnected_peers()
+ {
+ let context_index = rng.gen_range(0..contexts.len());
+ let context = &contexts[context_index];
+
+ match rng.gen_range(0..100) {
+ 0..=29 if mutation_count > 0 => {
+ log::info!("Context {}: edit buffer", context_index);
+ context.update(cx, |context, cx| {
+ context
+ .buffer
+ .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
+ });
+ mutation_count -= 1;
+ }
+ 30..=44 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
+ log::info!("Context {}: split message at {:?}", context_index, range);
+ context.split_message(range, cx);
+ });
+ mutation_count -= 1;
+ }
+ 45..=59 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ if let Some(message) = context.messages(cx).choose(&mut rng) {
+ let role = *[Role::User, Role::Assistant, Role::System]
+ .choose(&mut rng)
+ .unwrap();
+ log::info!(
+ "Context {}: insert message after {:?} with {:?}",
+ context_index,
+ message.id,
+ role
+ );
+ context.insert_message_after(message.id, role, MessageStatus::Done, cx);
+ }
+ });
+ mutation_count -= 1;
+ }
+ 60..=74 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ let command_text = "/".to_string()
+ + slash_commands
+ .command_names()
+ .choose(&mut rng)
+ .unwrap()
+ .clone()
+ .as_ref();
+
+ let command_range = context.buffer.update(cx, |buffer, cx| {
+ let offset = buffer.random_byte_range(0, &mut rng).start;
+ buffer.edit(
+ [(offset..offset, format!("\n{}\n", command_text))],
+ None,
+ cx,
+ );
+ offset + 1..offset + 1 + command_text.len()
+ });
+
+ let output_len = rng.gen_range(1..=10);
+ let output_text = RandomCharIter::new(&mut rng)
+ .filter(|c| *c != '\r')
+ .take(output_len)
+ .collect::<String>();
+
+ let num_sections = rng.gen_range(0..=3);
+ let mut sections = Vec::with_capacity(num_sections);
+ for _ in 0..num_sections {
+ let section_start = rng.gen_range(0..output_len);
+ let section_end = rng.gen_range(section_start..=output_len);
+ sections.push(SlashCommandOutputSection {
+ range: section_start..section_end,
+ icon: ui::IconName::Ai,
+ label: "section".into(),
+ });
+ }
+
+ log::info!(
+ "Context {}: insert slash command output at {:?} with {:?}",
+ context_index,
+ command_range,
+ sections
+ );
+
+ let command_range = context.buffer.read(cx).anchor_after(command_range.start)
+ ..context.buffer.read(cx).anchor_after(command_range.end);
+ context.insert_command_output(
+ command_range,
+ Task::ready(Ok(SlashCommandOutput {
+ text: output_text,
+ sections,
+ run_commands_in_text: false,
+ })),
+ true,
+ cx,
+ );
+ });
+ cx.run_until_parked();
+ mutation_count -= 1;
+ }
+ 75..=84 if mutation_count > 0 => {
+ context.update(cx, |context, cx| {
+ if let Some(message) = context.messages(cx).choose(&mut rng) {
+ let new_status = match rng.gen_range(0..3) {
+ 0 => MessageStatus::Done,
+ 1 => MessageStatus::Pending,
+ _ => MessageStatus::Error(SharedString::from("Random error")),
+ };
+ log::info!(
+ "Context {}: update message {:?} status to {:?}",
+ context_index,
+ message.id,
+ new_status
+ );
+ context.update_metadata(message.id, cx, |metadata| {
+ metadata.status = new_status;
+ });
+ }
+ });
+ mutation_count -= 1;
+ }
+ _ => {
+ let replica_id = context_index as ReplicaId;
+ if network.lock().is_disconnected(replica_id) {
+ network.lock().reconnect_peer(replica_id, 0);
+
+ let (ops_to_send, ops_to_receive) = cx.read(|cx| {
+ let host_context = &contexts[0].read(cx);
+ let guest_context = context.read(cx);
+ (
+ guest_context.serialize_ops(&host_context.version(cx), cx),
+ host_context.serialize_ops(&guest_context.version(cx), cx),
+ )
+ });
+ let ops_to_send = ops_to_send.await;
+ let ops_to_receive = ops_to_receive
+ .await
+ .into_iter()
+ .map(ContextOperation::from_proto)
+ .collect::<Result<Vec<_>>>()
+ .unwrap();
+ log::info!(
+ "Context {}: reconnecting. Sent {} operations, received {} operations",
+ context_index,
+ ops_to_send.len(),
+ ops_to_receive.len()
+ );
+
+ network.lock().broadcast(replica_id, ops_to_send);
+ context
+ .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
+ .unwrap();
+ } else if rng.gen_bool(0.1) && replica_id != 0 {
+ log::info!("Context {}: disconnecting", context_index);
+ network.lock().disconnect_peer(replica_id);
+ } else if network.lock().has_unreceived(replica_id) {
+ log::info!("Context {}: applying operations", context_index);
+ let ops = network.lock().receive(replica_id);
+ let ops = ops
+ .into_iter()
+ .map(ContextOperation::from_proto)
+ .collect::<Result<Vec<_>>>()
+ .unwrap();
+ context
+ .update(cx, |context, cx| context.apply_ops(ops, cx))
+ .unwrap();
+ }
+ }
+ }
+ }
+
+ cx.read(|cx| {
+ let first_context = contexts[0].read(cx);
+ for context in &contexts[1..] {
+ let context = context.read(cx);
+ assert!(context.pending_ops.is_empty());
+ assert_eq!(
+ context.buffer.read(cx).text(),
+ first_context.buffer.read(cx).text(),
+ "Context {} text != Context 0 text",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.message_anchors,
+ first_context.message_anchors,
+ "Context {} messages != Context 0 messages",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.messages_metadata,
+ first_context.messages_metadata,
+ "Context {} message metadata != Context 0 message metadata",
+ context.buffer.read(cx).replica_id()
+ );
+ assert_eq!(
+ context.slash_command_output_sections,
+ first_context.slash_command_output_sections,
+ "Context {} slash command output sections != Context 0 slash command output sections",
+ context.buffer.read(cx).replica_id()
+ );
+ }
+ });
+}
+
+fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
+ context
+ .read(cx)
+ .messages(cx)
+ .map(|message| (message.id, message.role, message.offset_range))
+ .collect()
+}
+
+#[derive(Clone)]
+struct FakeSlashCommand(String);
+
+impl SlashCommand for FakeSlashCommand {
+ fn name(&self) -> String {
+ self.0.clone()
+ }
+
+ fn description(&self) -> String {
+ format!("Fake slash command: {}", self.0)
+ }
+
+ fn menu_text(&self) -> String {
+ format!("Run fake command: {}", self.0)
+ }
+
+ fn complete_argument(
+ self: Arc<Self>,
+ _arguments: &[String],
+ _cancel: Arc<AtomicBool>,
+ _workspace: Option<WeakView<Workspace>>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<Vec<ArgumentCompletion>>> {
+ Task::ready(Ok(vec![]))
+ }
+
+ fn requires_argument(&self) -> bool {
+ false
+ }
+
+ fn run(
+ self: Arc<Self>,
+ _arguments: &[String],
+ _workspace: WeakView<Workspace>,
+ _delegate: Option<Arc<dyn LspAdapterDelegate>>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<SlashCommandOutput>> {
+ Task::ready(Ok(SlashCommandOutput {
+ text: format!("Executed fake command: {}", self.0),
+ sections: vec![],
+ run_commands_in_text: false,
+ }))
+ }
+}
@@ -1,223 +0,0 @@
-use std::{ops::Range, sync::Arc};
-
-use collections::{HashMap, HashSet};
-use editor::{
- display_map::{BlockDisposition, BlockProperties, BlockStyle, CustomBlockId},
- Editor,
-};
-use gpui::{AppContext, Model, View};
-use text::{Bias, ToOffset, ToPoint};
-use ui::{
- div, h_flex, px, Color, Element as _, ParentElement as _, Styled, ViewContext, WindowContext,
-};
-
-use crate::{Context, ResolvedWorkflowStep, WorkflowSuggestion};
-
-type StepRange = Range<language::Anchor>;
-
-struct DebugInfo {
- range: Range<editor::Anchor>,
- block_id: CustomBlockId,
-}
-
-pub(crate) struct ContextInspector {
- active_debug_views: HashMap<Range<language::Anchor>, DebugInfo>,
- context: Model<Context>,
- editor: View<Editor>,
-}
-
-impl ContextInspector {
- pub(crate) fn new(editor: View<Editor>, context: Model<Context>) -> Self {
- Self {
- editor,
- context,
- active_debug_views: Default::default(),
- }
- }
-
- pub(crate) fn is_active(&self, range: &StepRange) -> bool {
- self.active_debug_views.contains_key(range)
- }
-
- pub(crate) fn refresh(&mut self, range: &StepRange, cx: &mut WindowContext<'_>) {
- if self.deactivate_for(range, cx) {
- self.activate_for_step(range.clone(), cx);
- }
- }
- fn crease_content(
- context: &Model<Context>,
- range: StepRange,
- cx: &mut AppContext,
- ) -> Option<Arc<str>> {
- use std::fmt::Write;
- let step = context.read(cx).workflow_step_for_range(range)?;
- let mut output = String::from("\n\n");
- match &step.status {
- crate::WorkflowStepStatus::Resolved(ResolvedWorkflowStep { title, suggestions }) => {
- writeln!(output, "Resolution:").ok()?;
- writeln!(output, " {title:?}").ok()?;
- if suggestions.is_empty() {
- writeln!(output, " No suggestions").ok()?;
- }
-
- for (buffer, suggestion_groups) in suggestions {
- let buffer = buffer.read(cx);
- let buffer_path = buffer
- .file()
- .and_then(|file| file.path().to_str())
- .unwrap_or("untitled");
- let snapshot = buffer.text_snapshot();
- writeln!(output, " {buffer_path}:").ok()?;
- for group in suggestion_groups {
- for suggestion in &group.suggestions {
- pretty_print_workflow_suggestion(&mut output, suggestion, &snapshot);
- }
- }
- }
- }
- crate::WorkflowStepStatus::Pending(_) => {
- writeln!(output, "Resolution: Pending").ok()?;
- }
- crate::WorkflowStepStatus::Error(error) => {
- writeln!(output, "Resolution: Error").ok()?;
- writeln!(output, "{error:?}").ok()?;
- }
- }
-
- Some(output.into())
- }
- pub(crate) fn activate_for_step(&mut self, range: StepRange, cx: &mut WindowContext<'_>) {
- let text = Self::crease_content(&self.context, range.clone(), cx)
- .unwrap_or_else(|| Arc::from("Error fetching debug info"));
- self.editor.update(cx, |editor, cx| {
- let buffer = editor.buffer().read(cx).as_singleton()?;
- let snapshot = buffer.read(cx).text_snapshot();
- let start_offset = range.end.to_offset(&snapshot) + 1;
- let start_offset = snapshot.clip_offset(start_offset, Bias::Right);
- let text_len = text.len();
- buffer.update(cx, |this, cx| {
- this.edit([(start_offset..start_offset, text)], None, cx);
- });
-
- let end_offset = start_offset + text_len;
- let multibuffer_snapshot = editor.buffer().read(cx).snapshot(cx);
- let anchor_before = multibuffer_snapshot.anchor_after(start_offset);
- let anchor_after = multibuffer_snapshot.anchor_before(end_offset);
-
- let block_id = editor
- .insert_blocks(
- [BlockProperties {
- position: anchor_after,
- height: 0,
- style: BlockStyle::Sticky,
- render: Box::new(move |cx| {
- div()
- .w_full()
- .px(cx.gutter_dimensions.full_width())
- .child(h_flex().h(px(1.)).bg(Color::Warning.color(cx)))
- .into_any()
- }),
- disposition: BlockDisposition::Below,
- priority: 0,
- }],
- None,
- cx,
- )
- .into_iter()
- .next()?;
- let info = DebugInfo {
- range: anchor_before..anchor_after,
- block_id,
- };
- self.active_debug_views.insert(range, info);
- Some(())
- });
- }
-
- fn deactivate_impl(editor: &mut Editor, debug_data: DebugInfo, cx: &mut ViewContext<Editor>) {
- editor.remove_blocks(HashSet::from_iter([debug_data.block_id]), None, cx);
- editor.edit([(debug_data.range, Arc::<str>::default())], cx)
- }
- pub(crate) fn deactivate_for(&mut self, range: &StepRange, cx: &mut WindowContext<'_>) -> bool {
- if let Some(debug_data) = self.active_debug_views.remove(range) {
- self.editor.update(cx, |this, cx| {
- Self::deactivate_impl(this, debug_data, cx);
- });
- true
- } else {
- false
- }
- }
-
- pub(crate) fn deactivate(&mut self, cx: &mut WindowContext<'_>) {
- let steps_to_disable = std::mem::take(&mut self.active_debug_views);
-
- self.editor.update(cx, move |editor, cx| {
- for (_, debug_data) in steps_to_disable {
- Self::deactivate_impl(editor, debug_data, cx);
- }
- });
- }
-}
-fn pretty_print_anchor(
- out: &mut String,
- anchor: &language::Anchor,
- snapshot: &text::BufferSnapshot,
-) {
- use std::fmt::Write;
- let point = anchor.to_point(snapshot);
- write!(out, "{}:{}", point.row, point.column).ok();
-}
-fn pretty_print_range(
- out: &mut String,
- range: &Range<language::Anchor>,
- snapshot: &text::BufferSnapshot,
-) {
- use std::fmt::Write;
- write!(out, " Range: ").ok();
- pretty_print_anchor(out, &range.start, snapshot);
- write!(out, "..").ok();
- pretty_print_anchor(out, &range.end, snapshot);
-}
-
-fn pretty_print_workflow_suggestion(
- out: &mut String,
- suggestion: &WorkflowSuggestion,
- snapshot: &text::BufferSnapshot,
-) {
- use std::fmt::Write;
- let (range, description, position) = match suggestion {
- WorkflowSuggestion::Update { range, description } => (Some(range), Some(description), None),
- WorkflowSuggestion::CreateFile { description } => (None, Some(description), None),
- WorkflowSuggestion::AppendChild {
- position,
- description,
- }
- | WorkflowSuggestion::InsertSiblingBefore {
- position,
- description,
- }
- | WorkflowSuggestion::InsertSiblingAfter {
- position,
- description,
- }
- | WorkflowSuggestion::PrependChild {
- position,
- description,
- } => (None, Some(description), Some(position)),
-
- WorkflowSuggestion::Delete { range } => (Some(range), None, None),
- };
- if let Some(description) = description {
- writeln!(out, " Description: {description}").ok();
- }
- if let Some(range) = range {
- pretty_print_range(out, range, snapshot);
- }
- if let Some(position) = position {
- write!(out, " Position: ").ok();
- pretty_print_anchor(out, position, snapshot);
- write!(out, "\n").ok();
- }
- write!(out, "\n").ok();
-}
@@ -19,6 +19,7 @@ use fs::Fs;
use futures::{
channel::mpsc,
future::{BoxFuture, LocalBoxFuture},
+ join,
stream::{self, BoxStream},
SinkExt, Stream, StreamExt,
};
@@ -45,6 +46,7 @@ use std::{
task::{self, Poll},
time::{Duration, Instant},
};
+use text::OffsetRangeExt as _;
use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt};
@@ -1107,6 +1109,7 @@ impl InlineAssistant {
editor.set_show_gutter(false, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_read_only(true);
+ editor.set_show_inline_completions(false);
editor.highlight_rows::<DeletedLines>(
Anchor::min()..=Anchor::max(),
Some(cx.theme().status().deleted_background),
@@ -1280,12 +1283,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
})
}
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
-pub enum InitialInsertion {
- NewlineBefore,
- NewlineAfter,
-}
-
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);
@@ -2360,6 +2357,15 @@ impl Codegen {
return Err(anyhow::anyhow!("invalid transformation range"));
};
+ let mut transform_context_range = transform_range.to_point(&transform_buffer);
+ transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
+ transform_context_range.start.column = 0;
+ transform_context_range.end =
+ (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
+ transform_context_range.end.column =
+ transform_buffer.line_len(transform_context_range.end.row);
+ let transform_context_range = transform_context_range.to_offset(&transform_buffer);
+
let selected_ranges = self
.selected_ranges
.iter()
@@ -2382,6 +2388,7 @@ impl Codegen {
transform_buffer,
transform_range,
selected_ranges,
+ transform_context_range,
)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
@@ -2393,6 +2400,7 @@ impl Codegen {
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
+ cache: false,
});
Ok(LanguageModelRequest {
@@ -2439,12 +2447,12 @@ impl Codegen {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
let mut edit_start = edit_range.start.to_offset(&snapshot);
- self.generation = cx.spawn(|this, mut cx| {
+ self.generation = cx.spawn(|codegen, mut cx| {
async move {
let chunks = stream.await;
let generate = async {
let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
- let diff: Task<anyhow::Result<()>> =
+ let line_based_stream_diff: Task<anyhow::Result<()>> =
cx.background_executor().spawn(async move {
let mut response_latency = None;
let request_start = Instant::now();
@@ -2495,10 +2503,10 @@ impl Codegen {
});
while let Some((char_ops, line_diff)) = diff_rx.next().await {
- this.update(&mut cx, |this, cx| {
- this.last_equal_ranges.clear();
+ codegen.update(&mut cx, |codegen, cx| {
+ codegen.last_equal_ranges.clear();
- let transaction = this.buffer.update(cx, |buffer, cx| {
+ let transaction = codegen.buffer.update(cx, |buffer, cx| {
// Avoid grouping assistant edits with user edits.
buffer.finalize_last_transaction(cx);
@@ -2523,23 +2531,24 @@ impl Codegen {
let edit_range = snapshot.anchor_after(edit_start)
..snapshot.anchor_before(edit_end);
edit_start = edit_end;
- this.last_equal_ranges.push(edit_range);
+ codegen.last_equal_ranges.push(edit_range);
None
}
}),
None,
cx,
);
- this.edit_position = Some(snapshot.anchor_after(edit_start));
+ codegen.edit_position = Some(snapshot.anchor_after(edit_start));
buffer.end_transaction(cx)
});
if let Some(transaction) = transaction {
- if let Some(first_transaction) = this.transformation_transaction_id
+ if let Some(first_transaction) =
+ codegen.transformation_transaction_id
{
// Group all assistant edits into the first transaction.
- this.buffer.update(cx, |buffer, cx| {
+ codegen.buffer.update(cx, |buffer, cx| {
buffer.merge_transactions(
transaction,
first_transaction,
@@ -2547,36 +2556,45 @@ impl Codegen {
)
});
} else {
- this.transformation_transaction_id = Some(transaction);
- this.buffer.update(cx, |buffer, cx| {
+ codegen.transformation_transaction_id = Some(transaction);
+ codegen.buffer.update(cx, |buffer, cx| {
buffer.finalize_last_transaction(cx)
});
}
}
- this.update_diff(edit_range.clone(), line_diff, cx);
+ codegen.reapply_line_based_diff(edit_range.clone(), line_diff, cx);
cx.notify();
})?;
}
- diff.await?;
+ // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
+ // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
+ // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
+ let batch_diff_task = codegen.update(&mut cx, |codegen, cx| {
+ codegen.reapply_batch_diff(edit_range.clone(), cx)
+ })?;
+ let (line_based_stream_diff, ()) =
+ join!(line_based_stream_diff, batch_diff_task);
+ line_based_stream_diff?;
anyhow::Ok(())
};
let result = generate.await;
- this.update(&mut cx, |this, cx| {
- this.last_equal_ranges.clear();
- if let Err(error) = result {
- this.status = CodegenStatus::Error(error);
- } else {
- this.status = CodegenStatus::Done;
- }
- cx.emit(CodegenEvent::Finished);
- cx.notify();
- })
- .ok();
+ codegen
+ .update(&mut cx, |this, cx| {
+ this.last_equal_ranges.clear();
+ if let Err(error) = result {
+ this.status = CodegenStatus::Error(error);
+ } else {
+ this.status = CodegenStatus::Done;
+ }
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ })
+ .ok();
}
});
cx.notify();
@@ -2608,7 +2626,7 @@ impl Codegen {
});
}
- fn update_diff(
+ fn reapply_line_based_diff(
&mut self,
edit_range: Range<Anchor>,
line_operations: Vec<LineOperation>,
@@ -2667,6 +2685,99 @@ impl Codegen {
cx.notify();
}
}
+
+ fn reapply_batch_diff(
+ &mut self,
+ edit_range: Range<Anchor>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<()> {
+ let old_snapshot = self.snapshot.clone();
+ let old_range = edit_range.to_point(&old_snapshot);
+ let new_snapshot = self.buffer.read(cx).snapshot(cx);
+ let new_range = edit_range.to_point(&new_snapshot);
+
+ cx.spawn(|codegen, mut cx| async move {
+ let (deleted_row_ranges, inserted_row_ranges) = cx
+ .background_executor()
+ .spawn(async move {
+ let old_text = old_snapshot
+ .text_for_range(
+ Point::new(old_range.start.row, 0)
+ ..Point::new(
+ old_range.end.row,
+ old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
+ ),
+ )
+ .collect::<String>();
+ let new_text = new_snapshot
+ .text_for_range(
+ Point::new(new_range.start.row, 0)
+ ..Point::new(
+ new_range.end.row,
+ new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
+ ),
+ )
+ .collect::<String>();
+
+ let mut old_row = old_range.start.row;
+ let mut new_row = new_range.start.row;
+ let batch_diff =
+ similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
+
+ let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
+ let mut inserted_row_ranges = Vec::new();
+ for change in batch_diff.iter_all_changes() {
+ let line_count = change.value().lines().count() as u32;
+ match change.tag() {
+ similar::ChangeTag::Equal => {
+ old_row += line_count;
+ new_row += line_count;
+ }
+ similar::ChangeTag::Delete => {
+ let old_end_row = old_row + line_count - 1;
+ let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
+
+ if let Some((_, last_deleted_row_range)) =
+ deleted_row_ranges.last_mut()
+ {
+ if *last_deleted_row_range.end() + 1 == old_row {
+ *last_deleted_row_range =
+ *last_deleted_row_range.start()..=old_end_row;
+ } else {
+ deleted_row_ranges.push((new_row, old_row..=old_end_row));
+ }
+ } else {
+ deleted_row_ranges.push((new_row, old_row..=old_end_row));
+ }
+
+ old_row += line_count;
+ }
+ similar::ChangeTag::Insert => {
+ let new_end_row = new_row + line_count - 1;
+ let start = new_snapshot.anchor_before(Point::new(new_row, 0));
+ let end = new_snapshot.anchor_before(Point::new(
+ new_end_row,
+ new_snapshot.line_len(MultiBufferRow(new_end_row)),
+ ));
+ inserted_row_ranges.push(start..=end);
+ new_row += line_count;
+ }
+ }
+ }
+
+ (deleted_row_ranges, inserted_row_ranges)
+ })
+ .await;
+
+ codegen
+ .update(&mut cx, |codegen, cx| {
+ codegen.diff.deleted_row_ranges = deleted_row_ranges;
+ codegen.diff.inserted_row_ranges = inserted_row_ranges;
+ cx.notify();
+ })
+ .ok();
+ })
+ }
}
struct StripInvalidSpans<T> {
@@ -295,5 +295,6 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
.menu(move |_cx| Some(picker_view.clone()))
.trigger(self.trigger)
.attach(gpui::AnchorCorner::BottomLeft)
+ .when_some(self.handle, |menu, handle| menu.with_handle(handle))
}
}
@@ -11,8 +11,8 @@ use futures::{
};
use fuzzy::StringMatchCandidate;
use gpui::{
- actions, point, size, transparent_black, AppContext, BackgroundExecutor, Bounds, EventEmitter,
- Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
+ actions, point, size, transparent_black, Action, AppContext, BackgroundExecutor, Bounds,
+ EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions,
};
use heed::{
@@ -38,7 +38,7 @@ use std::{
use text::LineEnding;
use theme::ThemeSettings;
use ui::{
- div, prelude::*, IconButtonShape, ListItem, ListItemSpacing, ParentElement, Render,
+ div, prelude::*, IconButtonShape, KeyBinding, ListItem, ListItemSpacing, ParentElement, Render,
SharedString, Styled, Tooltip, ViewContext, VisualContext,
};
use util::{ResultExt, TryFutureExt};
@@ -100,7 +100,7 @@ pub fn open_prompt_library(
WindowOptions {
titlebar: Some(TitlebarOptions {
title: Some("Prompt Library".into()),
- appears_transparent: true,
+ appears_transparent: !cfg!(windows),
traffic_light_position: Some(point(px(9.0), px(9.0))),
}),
window_bounds: Some(WindowBounds::Windowed(bounds)),
@@ -155,6 +155,14 @@ impl PickerDelegate for PromptPickerDelegate {
self.matches.len()
}
+ fn no_matches_text(&self, _cx: &mut WindowContext) -> SharedString {
+ if self.store.prompt_count() == 0 {
+ "No prompts.".into()
+ } else {
+ "No prompts found matching your search.".into()
+ }
+ }
+
fn selected_index(&self) -> usize {
self.selected_index
}
@@ -486,7 +494,10 @@ impl PromptLibrary {
let mut editor = Editor::auto_width(cx);
editor.set_placeholder_text("Untitled", cx);
editor.set_text(prompt_metadata.title.unwrap_or_default(), cx);
- editor.set_read_only(prompt_id.is_built_in());
+ if prompt_id.is_built_in() {
+ editor.set_read_only(true);
+ editor.set_show_inline_completions(false);
+ }
editor
});
let body_editor = cx.new_view(|cx| {
@@ -498,7 +509,10 @@ impl PromptLibrary {
});
let mut editor = Editor::for_buffer(buffer, None, cx);
- editor.set_read_only(prompt_id.is_built_in());
+ if prompt_id.is_built_in() {
+ editor.set_read_only(true);
+ editor.set_show_inline_completions(false);
+ }
editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
editor.set_show_gutter(false, cx);
editor.set_show_wrap_guides(false, cx);
@@ -776,6 +790,7 @@ impl PromptLibrary {
messages: vec![LanguageModelRequestMessage {
role: Role::System,
content: vec![body.to_string().into()],
+ cache: false,
}],
stop: Vec::new(),
temperature: 1.,
@@ -1094,7 +1109,55 @@ impl Render for PromptLibrary {
.font(ui_font)
.text_color(theme.colors().text)
.child(self.render_prompt_list(cx))
- .child(self.render_active_prompt(cx))
+ .map(|el| {
+ if self.store.prompt_count() == 0 {
+ el.child(
+ v_flex()
+ .w_2_3()
+ .h_full()
+ .items_center()
+ .justify_center()
+ .gap_4()
+ .bg(cx.theme().colors().editor_background)
+ .child(
+ h_flex()
+ .gap_2()
+ .child(
+ Icon::new(IconName::Book)
+ .size(IconSize::Medium)
+ .color(Color::Muted),
+ )
+ .child(
+ Label::new("No prompts yet")
+ .size(LabelSize::Large)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ h_flex()
+ .child(h_flex())
+ .child(
+ v_flex()
+ .gap_1()
+ .child(Label::new("Create your first prompt:"))
+ .child(
+ Button::new("create-prompt", "New Prompt")
+ .full_width()
+ .key_binding(KeyBinding::for_action(
+ &NewPrompt, cx,
+ ))
+ .on_click(|_, cx| {
+ cx.dispatch_action(NewPrompt.boxed_clone())
+ }),
+ ),
+ )
+ .child(h_flex()),
+ ),
+ )
+ } else {
+ el.child(self.render_active_prompt(cx))
+ }
+ })
}
}
@@ -1342,6 +1405,11 @@ impl PromptStore {
})
}
+ /// Returns the number of prompts in the store.
+ fn prompt_count(&self) -> usize {
+ self.metadata_cache.read().metadata.len()
+ }
+
fn metadata(&self, id: PromptId) -> Option<PromptMetadata> {
self.metadata_cache.read().metadata_by_id.get(&id).cloned()
}
@@ -16,7 +16,9 @@ pub struct ContentPromptContext {
pub document_content: String,
pub user_prompt: String,
pub rewrite_section: String,
- pub rewrite_section_with_selections: String,
+ pub rewrite_section_prefix: String,
+ pub rewrite_section_suffix: String,
+ pub rewrite_section_with_edits: String,
pub has_insertion: bool,
pub has_replacement: bool,
}
@@ -31,6 +33,15 @@ pub struct TerminalAssistantPromptContext {
pub user_prompt: String,
}
+/// Context required to generate a workflow step resolution prompt.
+#[derive(Debug, Serialize)]
+pub struct StepResolutionContext {
+ /// The full context, including <step>...</step> tags
+ pub workflow_context: String,
+ /// The text of the specific step from the context to resolve
+ pub step_to_resolve: String,
+}
+
pub struct PromptBuilder {
handlebars: Arc<Mutex<Handlebars<'static>>>,
}
@@ -164,6 +175,7 @@ impl PromptBuilder {
buffer: BufferSnapshot,
transform_range: Range<usize>,
selected_ranges: Vec<Range<usize>>,
+ transform_context_range: Range<usize>,
) -> Result<String, RenderError> {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@@ -193,6 +205,7 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
+
document_content.push_str("<rewrite_this>\n");
for chunk in buffer.text_for_range(transform_range.clone()) {
document_content.push_str(chunk);
@@ -208,7 +221,17 @@ impl PromptBuilder {
rewrite_section.push_str(chunk);
}
- let rewrite_section_with_selections = {
+ let mut rewrite_section_prefix = String::new();
+ for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
+ rewrite_section_prefix.push_str(chunk);
+ }
+
+ let mut rewrite_section_suffix = String::new();
+ for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
+ rewrite_section_suffix.push_str(chunk);
+ }
+
+ let rewrite_section_with_edits = {
let mut section_with_selections = String::new();
let mut last_end = 0;
for selected_range in &selected_ranges {
@@ -245,7 +268,9 @@ impl PromptBuilder {
document_content,
user_prompt,
rewrite_section,
- rewrite_section_with_selections,
+ rewrite_section_prefix,
+ rewrite_section_suffix,
+ rewrite_section_with_edits,
has_insertion,
has_replacement,
};
@@ -278,7 +303,10 @@ impl PromptBuilder {
self.handlebars.lock().render("edit_workflow", &())
}
- pub fn generate_step_resolution_prompt(&self) -> Result<String, RenderError> {
- self.handlebars.lock().render("step_resolution", &())
+ pub fn generate_step_resolution_prompt(
+ &self,
+ context: &StepResolutionContext,
+ ) -> Result<String, RenderError> {
+ self.handlebars.lock().render("step_resolution", context)
}
}
@@ -1,5 +1,6 @@
use crate::assistant_panel::ContextEditor;
use anyhow::Result;
+use assistant_slash_command::AfterCompletion;
pub use assistant_slash_command::{SlashCommand, SlashCommandOutput, SlashCommandRegistry};
use editor::{CompletionProvider, Editor};
use fuzzy::{match_strings, StringMatchCandidate};
@@ -18,6 +19,7 @@ use std::{
use ui::ActiveTheme;
use workspace::Workspace;
+pub mod context_server_command;
pub mod default_command;
pub mod diagnostics_command;
pub mod docs_command;
@@ -96,20 +98,25 @@ impl SlashCommandCompletionProvider {
let command = commands.command(&mat.string)?;
let mut new_text = mat.string.clone();
let requires_argument = command.requires_argument();
- if requires_argument {
+ let accepts_arguments = command.accepts_arguments();
+ if requires_argument || accepts_arguments {
new_text.push(' ');
}
- let confirm = editor.clone().zip(workspace.clone()).and_then(
- |(editor, workspace)| {
- (!requires_argument).then(|| {
+ let confirm =
+ editor
+ .clone()
+ .zip(workspace.clone())
+ .map(|(editor, workspace)| {
let command_name = mat.string.clone();
let command_range = command_range.clone();
let editor = editor.clone();
let workspace = workspace.clone();
Arc::new(
move |intent: CompletionIntent, cx: &mut WindowContext| {
- if intent.is_complete() {
+ if !requires_argument
+ && (!accepts_arguments || intent.is_complete())
+ {
editor
.update(cx, |editor, cx| {
editor.run_command(
@@ -122,12 +129,13 @@ impl SlashCommandCompletionProvider {
);
})
.ok();
+ false
+ } else {
+ requires_argument || accepts_arguments
}
},
) as Arc<_>
- })
- },
- );
+ });
Some(project::Completion {
old_range: name_range.clone(),
documentation: Some(Documentation::SingleLine(command.description())),
@@ -135,7 +143,6 @@ impl SlashCommandCompletionProvider {
label: command.label(cx),
server_id: LanguageServerId(0),
lsp_completion: Default::default(),
- show_new_completions_on_confirm: requires_argument,
confirm,
})
})
@@ -174,7 +181,7 @@ impl SlashCommandCompletionProvider {
.await?
.into_iter()
.map(|new_argument| {
- let confirm = if new_argument.run_command {
+ let confirm =
editor
.clone()
.zip(workspace.clone())
@@ -191,7 +198,9 @@ impl SlashCommandCompletionProvider {
let command_range = command_range.clone();
let command_name = command_name.clone();
move |intent: CompletionIntent, cx: &mut WindowContext| {
- if intent.is_complete() {
+ if new_argument.after_completion.run()
+ || intent.is_complete()
+ {
editor
.update(cx, |editor, cx| {
editor.run_command(
@@ -204,16 +213,16 @@ impl SlashCommandCompletionProvider {
);
})
.ok();
+ false
+ } else {
+ !new_argument.after_completion.run()
}
}
}) as Arc<_>
- })
- } else {
- None
- };
+ });
let mut new_text = new_argument.new_text.clone();
- if !new_argument.run_command {
+ if new_argument.after_completion == AfterCompletion::Continue {
new_text.push(' ');
}
@@ -228,7 +237,6 @@ impl SlashCommandCompletionProvider {
documentation: None,
server_id: LanguageServerId(0),
lsp_completion: Default::default(),
- show_new_completions_on_confirm: !new_argument.run_command,
confirm,
}
})
@@ -0,0 +1,125 @@
+use anyhow::{anyhow, Result};
+use assistant_slash_command::{
+ ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
+};
+use collections::HashMap;
+use context_servers::{
+ manager::{ContextServer, ContextServerManager},
+ protocol::PromptInfo,
+};
+use gpui::{Task, WeakView, WindowContext};
+use language::LspAdapterDelegate;
+use std::sync::atomic::AtomicBool;
+use std::sync::Arc;
+use ui::{IconName, SharedString};
+use workspace::Workspace;
+
+pub struct ContextServerSlashCommand {
+ server_id: String,
+ prompt: PromptInfo,
+}
+
+impl ContextServerSlashCommand {
+ pub fn new(server: &Arc<ContextServer>, prompt: PromptInfo) -> Self {
+ Self {
+ server_id: server.id.clone(),
+ prompt,
+ }
+ }
+}
+
+impl SlashCommand for ContextServerSlashCommand {
+ fn name(&self) -> String {
+ self.prompt.name.clone()
+ }
+
+ fn description(&self) -> String {
+ format!("Run context server command: {}", self.prompt.name)
+ }
+
+ fn menu_text(&self) -> String {
+ format!("Run '{}' from {}", self.prompt.name, self.server_id)
+ }
+
+ fn requires_argument(&self) -> bool {
+ self.prompt
+ .arguments
+ .as_ref()
+ .map_or(false, |args| !args.is_empty())
+ }
+
+ fn complete_argument(
+ self: Arc<Self>,
+ _arguments: &[String],
+ _cancel: Arc<AtomicBool>,
+ _workspace: Option<WeakView<Workspace>>,
+ _cx: &mut WindowContext,
+ ) -> Task<Result<Vec<ArgumentCompletion>>> {
+ Task::ready(Ok(Vec::new()))
+ }
+
+ fn run(
+ self: Arc<Self>,
+ arguments: &[String],
+ _workspace: WeakView<Workspace>,
+ _delegate: Option<Arc<dyn LspAdapterDelegate>>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<SlashCommandOutput>> {
+ let server_id = self.server_id.clone();
+ let prompt_name = self.prompt.name.clone();
+ let argument = arguments.first().cloned();
+
+ let manager = ContextServerManager::global(cx);
+ let manager = manager.read(cx);
+ if let Some(server) = manager.get_server(&server_id) {
+ cx.foreground_executor().spawn(async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return Err(anyhow!("Context server not initialized"));
+ };
+
+ let result = protocol
+ .run_prompt(&prompt_name, prompt_arguments(&self.prompt, argument)?)
+ .await?;
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: 0..result.len(),
+ icon: IconName::ZedAssistant,
+ label: SharedString::from(format!("Result from {}", prompt_name)),
+ }],
+ text: result,
+ run_commands_in_text: false,
+ })
+ })
+ } else {
+ Task::ready(Err(anyhow!("Context server not found")))
+ }
+ }
+}
+
+fn prompt_arguments(
+ prompt: &PromptInfo,
+ argument: Option<String>,
+) -> Result<HashMap<String, String>> {
+ match &prompt.arguments {
+ Some(args) if args.len() >= 2 => Err(anyhow!(
+ "Prompt has more than one argument, which is not supported"
+ )),
+ Some(args) if args.len() == 1 => match argument {
+ Some(value) => Ok(HashMap::from_iter([(args[0].name.clone(), value)])),
+ None => Err(anyhow!("Prompt expects argument but none given")),
+ },
+ Some(_) | None => Ok(HashMap::default()),
+ }
+}
+
+/// MCP servers can return prompts with multiple arguments. Since we only
+/// support a single argument, we skip any prompt that declares more than one.
+/// This predicate determines whether a prompt can be exposed as a slash command.
+pub fn acceptable_prompt(prompt: &PromptInfo) -> bool {
+ match &prompt.arguments {
+ None => true,
+ Some(args) if args.len() == 1 => true,
+ _ => false,
+ }
+}
@@ -103,6 +103,10 @@ impl SlashCommand for DiagnosticsSlashCommand {
false
}
+ fn accepts_arguments(&self) -> bool {
+ true
+ }
+
fn complete_argument(
self: Arc<Self>,
arguments: &[String],
@@ -149,7 +153,7 @@ impl SlashCommand for DiagnosticsSlashCommand {
.map(|completion| ArgumentCompletion {
label: completion.clone().into(),
new_text: completion,
- run_command: true,
+ after_completion: assistant_slash_command::AfterCompletion::Run,
replace_previous_arguments: false,
})
.collect())
@@ -181,7 +181,7 @@ impl SlashCommand for DocsSlashCommand {
.map(|item| ArgumentCompletion {
label: item.clone().into(),
new_text: item.to_string(),
- run_command: true,
+ after_completion: assistant_slash_command::AfterCompletion::Run,
replace_previous_arguments: false,
})
.collect()
@@ -194,7 +194,7 @@ impl SlashCommand for DocsSlashCommand {
return Ok(vec![ArgumentCompletion {
label: "No available docs providers.".into(),
new_text: String::new(),
- run_command: false,
+ after_completion: false.into(),
replace_previous_arguments: false,
}]);
}
@@ -204,7 +204,7 @@ impl SlashCommand for DocsSlashCommand {
.map(|provider| ArgumentCompletion {
label: provider.to_string().into(),
new_text: provider.to_string(),
- run_command: false,
+ after_completion: false.into(),
replace_previous_arguments: false,
})
.collect())
@@ -236,7 +236,7 @@ impl SlashCommand for DocsSlashCommand {
.map(|package_name| ArgumentCompletion {
label: format!("{package_name} (unindexed)").into(),
new_text: format!("{package_name}"),
- run_command: true,
+ after_completion: true.into(),
replace_previous_arguments: false,
})
.collect::<Vec<_>>();
@@ -250,7 +250,7 @@ impl SlashCommand for DocsSlashCommand {
)
.into(),
new_text: provider.to_string(),
- run_command: false,
+ after_completion: false.into(),
replace_previous_arguments: false,
}]);
}
@@ -1,6 +1,6 @@
use super::{diagnostics_command::write_single_file_diagnostics, SlashCommand, SlashCommandOutput};
use anyhow::{anyhow, Context as _, Result};
-use assistant_slash_command::{ArgumentCompletion, SlashCommandOutputSection};
+use assistant_slash_command::{AfterCompletion, ArgumentCompletion, SlashCommandOutputSection};
use fuzzy::PathMatch;
use gpui::{AppContext, Model, Task, View, WeakView};
use language::{BufferSnapshot, CodeLabel, HighlightId, LineEnding, LspAdapterDelegate};
@@ -12,7 +12,7 @@ use std::{
sync::{atomic::AtomicBool, Arc},
};
use ui::prelude::*;
-use util::{paths::PathMatcher, ResultExt};
+use util::ResultExt;
use workspace::Workspace;
pub(crate) struct FileSlashCommand;
@@ -164,7 +164,11 @@ impl SlashCommand for FileSlashCommand {
Some(ArgumentCompletion {
label,
new_text: text,
- run_command: true,
+ after_completion: if path_match.is_dir {
+ AfterCompletion::Compose
+ } else {
+ AfterCompletion::Run
+ },
replace_previous_arguments: false,
})
})
@@ -190,16 +194,17 @@ impl SlashCommand for FileSlashCommand {
let task = collect_files(workspace.read(cx).project().clone(), arguments, cx);
cx.foreground_executor().spawn(async move {
- let (text, ranges) = task.await?;
+ let output = task.await?;
Ok(SlashCommandOutput {
- text,
- sections: ranges
+ text: output.completion_text,
+ sections: output
+ .files
.into_iter()
- .map(|(range, path, entry_type)| {
+ .map(|file| {
build_entry_output_section(
- range,
- Some(&path),
- entry_type == EntryType::Directory,
+ file.range_in_text,
+ Some(&file.path),
+ file.entry_type == EntryType::Directory,
None,
)
})
@@ -210,24 +215,37 @@ impl SlashCommand for FileSlashCommand {
}
}
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, PartialEq, Debug)]
enum EntryType {
File,
Directory,
}
+#[derive(Clone, PartialEq, Debug)]
+struct FileCommandOutput {
+ completion_text: String,
+ files: Vec<OutputFile>,
+}
+
+#[derive(Clone, PartialEq, Debug)]
+struct OutputFile {
+ range_in_text: Range<usize>,
+ path: PathBuf,
+ entry_type: EntryType,
+}
+
fn collect_files(
project: Model<Project>,
glob_inputs: &[String],
cx: &mut AppContext,
-) -> Task<Result<(String, Vec<(Range<usize>, PathBuf, EntryType)>)>> {
+) -> Task<Result<FileCommandOutput>> {
let Ok(matchers) = glob_inputs
.into_iter()
.map(|glob_input| {
- PathMatcher::new(&[glob_input.to_owned()])
+ custom_path_matcher::PathMatcher::new(&[glob_input.to_owned()])
.with_context(|| format!("invalid path {glob_input}"))
})
- .collect::<anyhow::Result<Vec<PathMatcher>>>()
+ .collect::<anyhow::Result<Vec<custom_path_matcher::PathMatcher>>>()
else {
return Task::ready(Err(anyhow!("invalid path")));
};
@@ -238,6 +256,7 @@ fn collect_files(
.worktrees(cx)
.map(|worktree| worktree.read(cx).snapshot())
.collect::<Vec<_>>();
+
cx.spawn(|mut cx| async move {
let mut text = String::new();
let mut ranges = Vec::new();
@@ -246,10 +265,12 @@ fn collect_files(
let mut directory_stack: Vec<(Arc<Path>, String, usize)> = Vec::new();
let mut folded_directory_names_stack = Vec::new();
let mut is_top_level_directory = true;
+
for entry in snapshot.entries(false, 0) {
let mut path_including_worktree_name = PathBuf::new();
path_including_worktree_name.push(snapshot.root_name());
path_including_worktree_name.push(&entry.path);
+
if !matchers
.iter()
.any(|matcher| matcher.is_match(&path_including_worktree_name))
@@ -262,11 +283,11 @@ fn collect_files(
break;
}
let (_, entry_name, start) = directory_stack.pop().unwrap();
- ranges.push((
- start..text.len().saturating_sub(1),
- PathBuf::from(entry_name),
- EntryType::Directory,
- ));
+ ranges.push(OutputFile {
+ range_in_text: start..text.len().saturating_sub(1),
+ path: PathBuf::from(entry_name),
+ entry_type: EntryType::Directory,
+ });
}
let filename = entry
@@ -339,24 +360,39 @@ fn collect_files(
) {
text.pop();
}
- ranges.push((
- prev_len..text.len(),
- path_including_worktree_name,
- EntryType::File,
- ));
+ ranges.push(OutputFile {
+ range_in_text: prev_len..text.len(),
+ path: path_including_worktree_name,
+ entry_type: EntryType::File,
+ });
text.push('\n');
}
}
}
- while let Some((dir, _, start)) = directory_stack.pop() {
- let mut root_path = PathBuf::new();
- root_path.push(snapshot.root_name());
- root_path.push(&dir);
- ranges.push((start..text.len(), root_path, EntryType::Directory));
+ while let Some((dir, entry, start)) = directory_stack.pop() {
+ if directory_stack.is_empty() {
+ let mut root_path = PathBuf::new();
+ root_path.push(snapshot.root_name());
+ root_path.push(&dir);
+ ranges.push(OutputFile {
+ range_in_text: start..text.len(),
+ path: root_path,
+ entry_type: EntryType::Directory,
+ });
+ } else {
+ ranges.push(OutputFile {
+ range_in_text: start..text.len(),
+ path: PathBuf::from(entry.as_str()),
+ entry_type: EntryType::Directory,
+ });
+ }
}
}
- Ok((text, ranges))
+ Ok(FileCommandOutput {
+ completion_text: text,
+ files: ranges,
+ })
})
}
@@ -424,3 +460,300 @@ pub fn build_entry_output_section(
label: label.into(),
}
}
+
+/// This contains a small fork of util::paths::PathMatcher that is stricter about the prefix
+/// check: only subpaths pass the prefix check, rather than any string prefix.
+mod custom_path_matcher {
+ use std::{fmt::Debug as _, path::Path};
+
+ use globset::{Glob, GlobSet, GlobSetBuilder};
+
+ #[derive(Clone, Debug, Default)]
+ pub struct PathMatcher {
+ sources: Vec<String>,
+ sources_with_trailing_slash: Vec<String>,
+ glob: GlobSet,
+ }
+
+ impl std::fmt::Display for PathMatcher {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.sources.fmt(f)
+ }
+ }
+
+ impl PartialEq for PathMatcher {
+ fn eq(&self, other: &Self) -> bool {
+ self.sources.eq(&other.sources)
+ }
+ }
+
+ impl Eq for PathMatcher {}
+
+ impl PathMatcher {
+ pub fn new(globs: &[String]) -> Result<Self, globset::Error> {
+ let globs = globs
+ .into_iter()
+ .map(|glob| Glob::new(&glob))
+ .collect::<Result<Vec<_>, _>>()?;
+ let sources = globs.iter().map(|glob| glob.glob().to_owned()).collect();
+ let sources_with_trailing_slash = globs
+ .iter()
+ .map(|glob| glob.glob().to_string() + std::path::MAIN_SEPARATOR_STR)
+ .collect();
+ let mut glob_builder = GlobSetBuilder::new();
+ for single_glob in globs {
+ glob_builder.add(single_glob);
+ }
+ let glob = glob_builder.build()?;
+ Ok(PathMatcher {
+ glob,
+ sources,
+ sources_with_trailing_slash,
+ })
+ }
+
+ pub fn sources(&self) -> &[String] {
+ &self.sources
+ }
+
+ pub fn is_match<P: AsRef<Path>>(&self, other: P) -> bool {
+ let other_path = other.as_ref();
+ self.sources
+ .iter()
+ .zip(self.sources_with_trailing_slash.iter())
+ .any(|(source, with_slash)| {
+ let as_bytes = other_path.as_os_str().as_encoded_bytes();
+ let with_slash = if source.ends_with("/") {
+ source.as_bytes()
+ } else {
+ with_slash.as_bytes()
+ };
+
+ as_bytes.starts_with(with_slash) || as_bytes.ends_with(source.as_bytes())
+ })
+ || self.glob.is_match(other_path)
+ || self.check_with_end_separator(other_path)
+ }
+
+ fn check_with_end_separator(&self, path: &Path) -> bool {
+ let path_str = path.to_string_lossy();
+ let separator = std::path::MAIN_SEPARATOR_STR;
+ if path_str.ends_with(separator) {
+ return false;
+ } else {
+ self.glob.is_match(path_str.to_string() + separator)
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use fs::FakeFs;
+ use gpui::TestAppContext;
+ use project::Project;
+ use serde_json::json;
+ use settings::SettingsStore;
+
+ use crate::slash_command::file_command::collect_files;
+
+ pub fn init_test(cx: &mut gpui::TestAppContext) {
+ if std::env::var("RUST_LOG").is_ok() {
+ env_logger::try_init().ok();
+ }
+
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ // release_channel::init(SemanticVersion::default(), cx);
+ language::init(cx);
+ Project::init_settings(cx);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_file_exact_matching(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/root",
+ json!({
+ "dir": {
+ "subdir": {
+ "file_0": "0"
+ },
+ "file_1": "1",
+ "file_2": "2",
+ "file_3": "3",
+ },
+ "dir.rs": "4"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, ["/root".as_ref()], cx).await;
+
+ let result_1 = cx
+ .update(|cx| collect_files(project.clone(), &["root/dir".to_string()], cx))
+ .await
+ .unwrap();
+
+ assert!(result_1.completion_text.starts_with("root/dir"));
+ // 4 files + 2 directories
+ assert_eq!(6, result_1.files.len());
+
+ let result_2 = cx
+ .update(|cx| collect_files(project.clone(), &["root/dir/".to_string()], cx))
+ .await
+ .unwrap();
+
+ assert_eq!(result_1, result_2);
+
+ let result = cx
+ .update(|cx| collect_files(project.clone(), &["root/dir*".to_string()], cx))
+ .await
+ .unwrap();
+
+ assert!(result.completion_text.starts_with("root/dir"));
+ // 5 files + 2 directories
+ assert_eq!(7, result.files.len());
+
+ // Ensure that the project lasts until after the last await
+ drop(project);
+ }
+
+ #[gpui::test]
+ async fn test_file_sub_directory_rendering(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/zed",
+ json!({
+ "assets": {
+ "dir1": {
+ ".gitkeep": ""
+ },
+ "dir2": {
+ ".gitkeep": ""
+ },
+ "themes": {
+ "ayu": {
+ "LICENSE": "1",
+ },
+ "andromeda": {
+ "LICENSE": "2",
+ },
+ "summercamp": {
+ "LICENSE": "3",
+ },
+ },
+ },
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, ["/zed".as_ref()], cx).await;
+
+ let result = cx
+ .update(|cx| collect_files(project.clone(), &["zed/assets/themes".to_string()], cx))
+ .await
+ .unwrap();
+
+ // Sanity check
+ assert!(result.completion_text.starts_with("zed/assets/themes\n"));
+ assert_eq!(7, result.files.len());
+
+ // Ensure that full file paths are included in the real output
+ assert!(result
+ .completion_text
+ .contains("zed/assets/themes/andromeda/LICENSE"));
+ assert!(result
+ .completion_text
+ .contains("zed/assets/themes/ayu/LICENSE"));
+ assert!(result
+ .completion_text
+ .contains("zed/assets/themes/summercamp/LICENSE"));
+
+ assert_eq!("summercamp", result.files[5].path.to_string_lossy());
+
+ // Ensure that things are in descending order, with properly relativized paths
+ assert_eq!(
+ "zed/assets/themes/andromeda/LICENSE",
+ result.files[0].path.to_string_lossy()
+ );
+ assert_eq!("andromeda", result.files[1].path.to_string_lossy());
+ assert_eq!(
+ "zed/assets/themes/ayu/LICENSE",
+ result.files[2].path.to_string_lossy()
+ );
+ assert_eq!("ayu", result.files[3].path.to_string_lossy());
+ assert_eq!(
+ "zed/assets/themes/summercamp/LICENSE",
+ result.files[4].path.to_string_lossy()
+ );
+
+ // Ensure that the project lasts until after the last await
+ drop(project);
+ }
+
+ #[gpui::test]
+ async fn test_file_deep_sub_directory_rendering(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ fs.insert_tree(
+ "/zed",
+ json!({
+ "assets": {
+ "themes": {
+ "LICENSE": "1",
+ "summercamp": {
+ "LICENSE": "1",
+ "subdir": {
+ "LICENSE": "1",
+ "subsubdir": {
+ "LICENSE": "3",
+ }
+ }
+ },
+ },
+ },
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs, ["/zed".as_ref()], cx).await;
+
+ let result = cx
+ .update(|cx| collect_files(project.clone(), &["zed/assets/themes".to_string()], cx))
+ .await
+ .unwrap();
+
+ assert!(result.completion_text.starts_with("zed/assets/themes\n"));
+ assert_eq!(
+ "zed/assets/themes/LICENSE",
+ result.files[0].path.to_string_lossy()
+ );
+ assert_eq!(
+ "zed/assets/themes/summercamp/LICENSE",
+ result.files[1].path.to_string_lossy()
+ );
+ assert_eq!(
+ "zed/assets/themes/summercamp/subdir/LICENSE",
+ result.files[2].path.to_string_lossy()
+ );
+ assert_eq!(
+ "zed/assets/themes/summercamp/subdir/subsubdir/LICENSE",
+ result.files[3].path.to_string_lossy()
+ );
+ assert_eq!("subsubdir", result.files[4].path.to_string_lossy());
+ assert_eq!("subdir", result.files[5].path.to_string_lossy());
+ assert_eq!("summercamp", result.files[6].path.to_string_lossy());
+ assert_eq!("zed/assets/themes", result.files[7].path.to_string_lossy());
+
+ // Ensure that the project lasts until after the last await
+ drop(project);
+ }
+}
@@ -45,7 +45,7 @@ impl SlashCommand for PromptSlashCommand {
Some(ArgumentCompletion {
label: prompt_title.clone().into(),
new_text: prompt_title,
- run_command: true,
+ after_completion: true.into(),
replace_previous_arguments: true,
})
})
@@ -39,6 +39,10 @@ impl SlashCommand for TabSlashCommand {
false
}
+ fn accepts_arguments(&self) -> bool {
+ true
+ }
+
fn complete_argument(
self: Arc<Self>,
arguments: &[String],
@@ -90,19 +94,20 @@ impl SlashCommand for TabSlashCommand {
label: path_string.clone().into(),
new_text: path_string,
replace_previous_arguments: false,
- run_command,
+ after_completion: run_command.into(),
})
});
- let active_item_completion = active_item_path.as_deref().map(|active_item_path| {
- let path_string = active_item_path.to_string_lossy().to_string();
- ArgumentCompletion {
+ let active_item_completion = active_item_path
+ .as_deref()
+ .map(|active_item_path| active_item_path.to_string_lossy().to_string())
+ .filter(|path_string| !argument_set.contains(path_string))
+ .map(|path_string| ArgumentCompletion {
label: path_string.clone().into(),
new_text: path_string,
replace_previous_arguments: false,
- run_command,
- }
- });
+ after_completion: run_command.into(),
+ });
Ok(active_item_completion
.into_iter()
@@ -110,7 +115,7 @@ impl SlashCommand for TabSlashCommand {
label: ALL_TABS_COMPLETION_ITEM.into(),
new_text: ALL_TABS_COMPLETION_ITEM.to_owned(),
replace_previous_arguments: false,
- run_command: true,
+ after_completion: true.into(),
}))
.chain(tab_completion_items)
.collect())
@@ -40,6 +40,10 @@ impl SlashCommand for TerminalSlashCommand {
false
}
+ fn accepts_arguments(&self) -> bool {
+ true
+ }
+
fn complete_argument(
self: Arc<Self>,
_arguments: &[String],
@@ -277,6 +277,7 @@ impl TerminalInlineAssistant {
messages.push(LanguageModelRequestMessage {
role: Role::User,
content: vec![prompt.into()],
+ cache: false,
});
Ok(LanguageModelRequest {
@@ -1,25 +1,42 @@
-### Using the Assistant
+## Assistant Panel
Once you have configured a provider, you can interact with the provider's language models in a context editor.
-To create a new context editor, use the menu in the top right of the assistant panel and the `New Context` option.
+To create a new context editor, use the menu in the top right of the assistant panel and select the `New Context` option.
In the context editor, select a model from one of the configured providers, type a message in the `You` block, and submit with `cmd-enter` (or `ctrl-enter` on Linux).
-### Inline assistant
-
-When you're in a normal editor, you can use `ctrl-enter` to open the inline assistant.
-
-The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response.
-
### Adding Prompts
-You can customize the default prompts that are used in new context editor, by opening the `Prompt Library`.
+You can customize the default prompts used in new context editors by opening the `Prompt Library`.
Open the `Prompt Library` using either the menu in the top right of the assistant panel and choosing the `Prompt Library` option, or by using the `assistant: deploy prompt library` command when the assistant panel is focused.
### Viewing past contexts
-You view all previous contexts by opening up the `History` tab in the assistant panel.
+You can view all previous contexts by opening the `History` tab in the assistant panel.
+
+Open the `History` using the menu in the top right of the assistant panel and choosing `History`.
+
+### Slash commands
+
+Slash commands enhance the assistant's capabilities. Begin by typing a `/` at the beginning of the line to see a list of available commands:
+
+- default: Inserts the default prompt into the context
+- diagnostics: Injects errors reported by the project's language server into the context
+- fetch: Pulls the content of a webpage and inserts it into the context
+- file: Pulls a single file or a directory of files into the context
+- now: Inserts the current date and time into the context
+- prompt: Adds a custom-configured prompt to the context (see Prompt Library)
+- search: Performs semantic search for content in your project based on natural language
+- symbols: Pulls the current tab's active symbols into the context
+- tab: Pulls in the content of the active tab or all open tabs into the context
+- terminal: Pulls in a selected number of lines of output from the terminal
+
+## Inline assistant
+
+You can use `ctrl-enter` to open the inline assistant in both a normal editor and within the assistant panel.
+
+The inline assistant allows you to send the current selection (or the current line) to a language model and modify the selection with the language model's response.
-Open the `History` using the menu in the top right of the assistant panel and choosing the `History`.
+The inline assistant pulls its context from the assistant panel, allowing you to provide additional instructions or rules for code transformations.
@@ -0,0 +1,763 @@
+mod step_view;
+
+use crate::{
+ prompts::StepResolutionContext, AssistantPanel, Context, InlineAssistId, InlineAssistant,
+};
+use anyhow::{anyhow, Error, Result};
+use collections::HashMap;
+use editor::Editor;
+use futures::future;
+use gpui::{
+ Model, ModelContext, Task, UpdateGlobal as _, View, WeakModel, WeakView, WindowContext,
+};
+use language::{Anchor, Buffer, BufferSnapshot, SymbolPath};
+use language_model::{LanguageModelRegistry, LanguageModelRequestMessage, Role};
+use project::Project;
+use rope::Point;
+use serde::{Deserialize, Serialize};
+use smol::stream::StreamExt;
+use std::{cmp, fmt::Write, ops::Range, sync::Arc};
+use text::{AnchorRangeExt as _, OffsetRangeExt as _};
+use util::ResultExt as _;
+use workspace::Workspace;
+
+pub use step_view::WorkflowStepView;
+
+pub struct WorkflowStep {
+    context: WeakModel<Context>, // assistant context that owns this step
+    context_buffer_range: Range<Anchor>, // span of this step's text inside the context buffer
+    tool_output: String, // accumulated streaming output of the resolution tool call
+    resolve_task: Option<Task<()>>, // in-flight resolution task, if any
+    pub resolution: Option<Result<WorkflowStepResolution, Arc<Error>>>, // result of the last resolve() run
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct WorkflowStepResolution {
+    pub title: String, // short step title produced by the resolution tool
+    pub suggestion_groups: HashMap<Model<Buffer>, Vec<WorkflowSuggestionGroup>>, // per-buffer suggestions, grouped where context ranges overlap
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct WorkflowSuggestionGroup {
+    pub context_range: Range<language::Anchor>, // suggestion range expanded by 5 rows of surrounding context on each side
+    pub suggestions: Vec<WorkflowSuggestion>, // suggestions whose expanded ranges overlapped and were merged
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum WorkflowSuggestion {
+    Update { // rewrite an existing symbol in place
+        symbol_path: SymbolPath,
+        range: Range<language::Anchor>, // full extent of the symbol, including its annotation lines
+        description: String,
+    },
+    CreateFile { // create a new file from the description
+        description: String,
+    },
+    InsertSiblingBefore { // insert new content immediately before the symbol
+        symbol_path: SymbolPath,
+        position: language::Anchor,
+        description: String,
+    },
+    InsertSiblingAfter { // insert new content immediately after the symbol
+        symbol_path: SymbolPath,
+        position: language::Anchor,
+        description: String,
+    },
+    PrependChild { // insert as the first child of the symbol; top of file when symbol_path is None
+        symbol_path: Option<SymbolPath>,
+        position: language::Anchor,
+        description: String,
+    },
+    AppendChild { // insert as the last child of the symbol; bottom of file when symbol_path is None
+        symbol_path: Option<SymbolPath>,
+        position: language::Anchor,
+        description: String,
+    },
+    Delete { // delete the symbol; carries no description
+        symbol_path: SymbolPath,
+        range: Range<language::Anchor>,
+    },
+}
+
+impl WorkflowStep {
+ pub fn new(range: Range<Anchor>, context: WeakModel<Context>) -> Self {
+ Self {
+ context_buffer_range: range,
+ tool_output: String::new(),
+ context,
+ resolution: None,
+ resolve_task: None,
+ }
+ }
+
+ pub fn resolve(&mut self, cx: &mut ModelContext<WorkflowStep>) -> Option<()> {
+ let range = self.context_buffer_range.clone();
+ let context = self.context.upgrade()?;
+ let context = context.read(cx);
+ let project = context.project()?;
+ let prompt_builder = context.prompt_builder();
+ let mut request = context.to_completion_request(cx);
+ let model = LanguageModelRegistry::read_global(cx).active_model();
+ let context_buffer = context.buffer();
+ let step_text = context_buffer
+ .read(cx)
+ .text_for_range(range.clone())
+ .collect::<String>();
+
+ let mut workflow_context = String::new();
+ for message in context.messages(cx) {
+ write!(&mut workflow_context, "<message role={}>", message.role).unwrap();
+ for chunk in context_buffer.read(cx).text_for_range(message.offset_range) {
+ write!(&mut workflow_context, "{chunk}").unwrap();
+ }
+ write!(&mut workflow_context, "</message>").unwrap();
+ }
+
+ self.resolve_task = Some(cx.spawn(|this, mut cx| async move {
+ let result = async {
+ let Some(model) = model else {
+ return Err(anyhow!("no model selected"));
+ };
+
+ this.update(&mut cx, |this, cx| {
+ this.tool_output.clear();
+ this.resolution = None;
+ this.result_updated(cx);
+ cx.notify();
+ })?;
+
+ let resolution_context = StepResolutionContext {
+ workflow_context,
+ step_to_resolve: step_text.clone(),
+ };
+ let mut prompt =
+ prompt_builder.generate_step_resolution_prompt(&resolution_context)?;
+ prompt.push_str(&step_text);
+ request.messages.push(LanguageModelRequestMessage {
+ role: Role::User,
+ content: vec![prompt.into()],
+ cache: false,
+ });
+
+ // Invoke the model to get its edit suggestions for this workflow step.
+ let mut stream = model
+ .use_tool_stream::<tool::WorkflowStepResolutionTool>(request, &cx)
+ .await?;
+ while let Some(chunk) = stream.next().await {
+ let chunk = chunk?;
+ this.update(&mut cx, |this, cx| {
+ this.tool_output.push_str(&chunk);
+ cx.notify();
+ })?;
+ }
+
+ let resolution = this.update(&mut cx, |this, _| {
+ serde_json::from_str::<tool::WorkflowStepResolutionTool>(&this.tool_output)
+ })??;
+
+ this.update(&mut cx, |this, cx| {
+ this.tool_output = serde_json::to_string_pretty(&resolution).unwrap();
+ cx.notify();
+ })?;
+
+ // Translate the parsed suggestions to our internal types, which anchor the suggestions to locations in the code.
+ let suggestion_tasks: Vec<_> = resolution
+ .suggestions
+ .iter()
+ .map(|suggestion| suggestion.resolve(project.clone(), cx.clone()))
+ .collect();
+
+ // Expand the context ranges of each suggestion and group suggestions with overlapping context ranges.
+ let suggestions = future::join_all(suggestion_tasks)
+ .await
+ .into_iter()
+ .filter_map(|task| task.log_err())
+ .collect::<Vec<_>>();
+
+ let mut suggestions_by_buffer = HashMap::default();
+ for (buffer, suggestion) in suggestions {
+ suggestions_by_buffer
+ .entry(buffer)
+ .or_insert_with(Vec::new)
+ .push(suggestion);
+ }
+
+ let mut suggestion_groups_by_buffer = HashMap::default();
+ for (buffer, mut suggestions) in suggestions_by_buffer {
+ let mut suggestion_groups = Vec::<WorkflowSuggestionGroup>::new();
+ let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
+ // Sort suggestions by their range so that earlier, larger ranges come first
+ suggestions.sort_by(|a, b| a.range().cmp(&b.range(), &snapshot));
+
+ // Merge overlapping suggestions
+ suggestions.dedup_by(|a, b| b.try_merge(a, &snapshot));
+
+ // Create context ranges for each suggestion
+ for suggestion in suggestions {
+ let context_range = {
+ let suggestion_point_range = suggestion.range().to_point(&snapshot);
+ let start_row = suggestion_point_range.start.row.saturating_sub(5);
+ let end_row = cmp::min(
+ suggestion_point_range.end.row + 5,
+ snapshot.max_point().row,
+ );
+ let start = snapshot.anchor_before(Point::new(start_row, 0));
+ let end = snapshot
+ .anchor_after(Point::new(end_row, snapshot.line_len(end_row)));
+ start..end
+ };
+
+ if let Some(last_group) = suggestion_groups.last_mut() {
+ if last_group
+ .context_range
+ .end
+ .cmp(&context_range.start, &snapshot)
+ .is_ge()
+ {
+ // Merge with the previous group if context ranges overlap
+ last_group.context_range.end = context_range.end;
+ last_group.suggestions.push(suggestion);
+ } else {
+ // Create a new group
+ suggestion_groups.push(WorkflowSuggestionGroup {
+ context_range,
+ suggestions: vec![suggestion],
+ });
+ }
+ } else {
+ // Create the first group
+ suggestion_groups.push(WorkflowSuggestionGroup {
+ context_range,
+ suggestions: vec![suggestion],
+ });
+ }
+ }
+
+ suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
+ }
+
+ Ok((resolution.step_title, suggestion_groups_by_buffer))
+ };
+
+ let result = result.await;
+ this.update(&mut cx, |this, cx| {
+ this.resolution = Some(match result {
+ Ok((title, suggestion_groups)) => Ok(WorkflowStepResolution {
+ title,
+ suggestion_groups,
+ }),
+ Err(error) => Err(Arc::new(error)),
+ });
+ this.context
+ .update(cx, |context, cx| context.workflow_step_updated(range, cx))
+ .ok();
+ cx.notify();
+ })
+ .ok();
+ }));
+ None
+ }
+
+ fn result_updated(&mut self, cx: &mut ModelContext<Self>) {
+ self.context
+ .update(cx, |context, cx| {
+ context.workflow_step_updated(self.context_buffer_range.clone(), cx)
+ })
+ .ok();
+ }
+}
+
+impl WorkflowSuggestion {
+ pub fn range(&self) -> Range<language::Anchor> {
+ match self {
+ Self::Update { range, .. } => range.clone(),
+ Self::CreateFile { .. } => language::Anchor::MIN..language::Anchor::MAX,
+ Self::InsertSiblingBefore { position, .. }
+ | Self::InsertSiblingAfter { position, .. }
+ | Self::PrependChild { position, .. }
+ | Self::AppendChild { position, .. } => *position..*position,
+ Self::Delete { range, .. } => range.clone(),
+ }
+ }
+
+ pub fn description(&self) -> Option<&str> {
+ match self {
+ Self::Update { description, .. }
+ | Self::CreateFile { description }
+ | Self::InsertSiblingBefore { description, .. }
+ | Self::InsertSiblingAfter { description, .. }
+ | Self::PrependChild { description, .. }
+ | Self::AppendChild { description, .. } => Some(description),
+ Self::Delete { .. } => None,
+ }
+ }
+
+ fn description_mut(&mut self) -> Option<&mut String> {
+ match self {
+ Self::Update { description, .. }
+ | Self::CreateFile { description }
+ | Self::InsertSiblingBefore { description, .. }
+ | Self::InsertSiblingAfter { description, .. }
+ | Self::PrependChild { description, .. }
+ | Self::AppendChild { description, .. } => Some(description),
+ Self::Delete { .. } => None,
+ }
+ }
+
+ fn symbol_path(&self) -> Option<&SymbolPath> {
+ match self {
+ Self::Update { symbol_path, .. } => Some(symbol_path),
+ Self::InsertSiblingBefore { symbol_path, .. } => Some(symbol_path),
+ Self::InsertSiblingAfter { symbol_path, .. } => Some(symbol_path),
+ Self::PrependChild { symbol_path, .. } => symbol_path.as_ref(),
+ Self::AppendChild { symbol_path, .. } => symbol_path.as_ref(),
+ Self::Delete { symbol_path, .. } => Some(symbol_path),
+ Self::CreateFile { .. } => None,
+ }
+ }
+
+ fn kind(&self) -> &str {
+ match self {
+ Self::Update { .. } => "Update",
+ Self::CreateFile { .. } => "CreateFile",
+ Self::InsertSiblingBefore { .. } => "InsertSiblingBefore",
+ Self::InsertSiblingAfter { .. } => "InsertSiblingAfter",
+ Self::PrependChild { .. } => "PrependChild",
+ Self::AppendChild { .. } => "AppendChild",
+ Self::Delete { .. } => "Delete",
+ }
+ }
+
+ fn try_merge(&mut self, other: &Self, buffer: &BufferSnapshot) -> bool {
+ let range = self.range();
+ let other_range = other.range();
+
+ // Don't merge if we don't contain the other suggestion.
+ if range.start.cmp(&other_range.start, buffer).is_gt()
+ || range.end.cmp(&other_range.end, buffer).is_lt()
+ {
+ return false;
+ }
+
+ if let Some(description) = self.description_mut() {
+ if let Some(other_description) = other.description() {
+ description.push('\n');
+ description.push_str(other_description);
+ }
+ }
+ true
+ }
+
+ pub fn show(
+ &self,
+ editor: &View<Editor>,
+ excerpt_id: editor::ExcerptId,
+ workspace: &WeakView<Workspace>,
+ assistant_panel: &View<AssistantPanel>,
+ cx: &mut WindowContext,
+ ) -> Option<InlineAssistId> {
+ let mut initial_transaction_id = None;
+ let initial_prompt;
+ let suggestion_range;
+ let buffer = editor.read(cx).buffer().clone();
+ let snapshot = buffer.read(cx).snapshot(cx);
+
+ match self {
+ Self::Update {
+ range, description, ..
+ } => {
+ initial_prompt = description.clone();
+ suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
+ ..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
+ }
+ Self::CreateFile { description } => {
+ initial_prompt = description.clone();
+ suggestion_range = editor::Anchor::min()..editor::Anchor::min();
+ }
+ Self::InsertSiblingBefore {
+ position,
+ description,
+ ..
+ } => {
+ let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
+ initial_prompt = description.clone();
+ suggestion_range = buffer.update(cx, |buffer, cx| {
+ buffer.start_transaction(cx);
+ let line_start = buffer.insert_empty_line(position, true, true, cx);
+ initial_transaction_id = buffer.end_transaction(cx);
+ buffer.refresh_preview(cx);
+
+ let line_start = buffer.read(cx).anchor_before(line_start);
+ line_start..line_start
+ });
+ }
+ Self::InsertSiblingAfter {
+ position,
+ description,
+ ..
+ } => {
+ let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
+ initial_prompt = description.clone();
+ suggestion_range = buffer.update(cx, |buffer, cx| {
+ buffer.start_transaction(cx);
+ let line_start = buffer.insert_empty_line(position, true, true, cx);
+ initial_transaction_id = buffer.end_transaction(cx);
+ buffer.refresh_preview(cx);
+
+ let line_start = buffer.read(cx).anchor_before(line_start);
+ line_start..line_start
+ });
+ }
+ Self::PrependChild {
+ position,
+ description,
+ ..
+ } => {
+ let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
+ initial_prompt = description.clone();
+ suggestion_range = buffer.update(cx, |buffer, cx| {
+ buffer.start_transaction(cx);
+ let line_start = buffer.insert_empty_line(position, false, true, cx);
+ initial_transaction_id = buffer.end_transaction(cx);
+ buffer.refresh_preview(cx);
+
+ let line_start = buffer.read(cx).anchor_before(line_start);
+ line_start..line_start
+ });
+ }
+ Self::AppendChild {
+ position,
+ description,
+ ..
+ } => {
+ let position = snapshot.anchor_in_excerpt(excerpt_id, *position)?;
+ initial_prompt = description.clone();
+ suggestion_range = buffer.update(cx, |buffer, cx| {
+ buffer.start_transaction(cx);
+ let line_start = buffer.insert_empty_line(position, true, false, cx);
+ initial_transaction_id = buffer.end_transaction(cx);
+ buffer.refresh_preview(cx);
+
+ let line_start = buffer.read(cx).anchor_before(line_start);
+ line_start..line_start
+ });
+ }
+ Self::Delete { range, .. } => {
+ initial_prompt = "Delete".to_string();
+ suggestion_range = snapshot.anchor_in_excerpt(excerpt_id, range.start)?
+ ..snapshot.anchor_in_excerpt(excerpt_id, range.end)?;
+ }
+ }
+
+ InlineAssistant::update_global(cx, |inline_assistant, cx| {
+ Some(inline_assistant.suggest_assist(
+ editor,
+ suggestion_range,
+ initial_prompt,
+ initial_transaction_id,
+ Some(workspace.clone()),
+ Some(assistant_panel),
+ cx,
+ ))
+ })
+ }
+}
+
+pub mod tool {
+ use super::*;
+ use anyhow::Context as _;
+ use gpui::AsyncAppContext;
+ use language::ParseStatus;
+ use language_model::LanguageModelTool;
+ use project::ProjectPath;
+ use schemars::JsonSchema;
+ use std::path::Path;
+
+ #[derive(Debug, Serialize, Deserialize, JsonSchema)]
+ pub struct WorkflowStepResolutionTool {
+ /// An extremely short title for the edit step represented by these operations.
+ pub step_title: String,
+ /// A sequence of operations to apply to the codebase.
+ /// When multiple operations are required for a step, be sure to include multiple operations in this list.
+ pub suggestions: Vec<WorkflowSuggestionTool>,
+ }
+
+ impl LanguageModelTool for WorkflowStepResolutionTool {
+ fn name() -> String {
+ "edit".into()
+ }
+
+ fn description() -> String {
+ "suggest edits to one or more locations in the codebase".into()
+ }
+ }
+
+ /// A description of an operation to apply to one location in the codebase.
+ ///
+ /// This object represents a single edit operation that can be performed on a specific file
+ /// in the codebase. It encapsulates both the location (file path) and the nature of the
+ /// edit to be made.
+ ///
+ /// # Fields
+ ///
+ /// * `path`: A string representing the file path where the edit operation should be applied.
+ /// This path is relative to the root of the project or repository.
+ ///
+ /// * `kind`: An enum representing the specific type of edit operation to be performed.
+ ///
+ /// # Usage
+ ///
+ /// `EditOperation` is used within a code editor to represent and apply
+ /// programmatic changes to source code. It provides a structured way to describe
+ /// edits for features like refactoring tools or AI-assisted coding suggestions.
+ #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
+ pub struct WorkflowSuggestionTool {
+ /// The path to the file containing the relevant operation
+ pub path: String,
+ #[serde(flatten)]
+ pub kind: WorkflowSuggestionToolKind,
+ }
+
+ impl WorkflowSuggestionTool {
+ pub(super) async fn resolve(
+ &self,
+ project: Model<Project>,
+ mut cx: AsyncAppContext,
+ ) -> Result<(Model<Buffer>, super::WorkflowSuggestion)> {
+ let path = self.path.clone();
+ let kind = self.kind.clone();
+ let buffer = project
+ .update(&mut cx, |project, cx| {
+ let project_path = project
+ .find_project_path(Path::new(&path), cx)
+ .or_else(|| {
+ // If we couldn't find a project path for it, put it in the active worktree
+ // so that when we create the buffer, it can be saved.
+ let worktree = project
+ .active_entry()
+ .and_then(|entry_id| project.worktree_for_entry(entry_id, cx))
+ .or_else(|| project.worktrees(cx).next())?;
+ let worktree = worktree.read(cx);
+
+ Some(ProjectPath {
+ worktree_id: worktree.id(),
+ path: Arc::from(Path::new(&path)),
+ })
+ })
+ .with_context(|| format!("worktree not found for {:?}", path))?;
+ anyhow::Ok(project.open_buffer(project_path, cx))
+ })??
+ .await?;
+
+ let mut parse_status = buffer.read_with(&cx, |buffer, _cx| buffer.parse_status())?;
+ while *parse_status.borrow() != ParseStatus::Idle {
+ parse_status.changed().await?;
+ }
+
+ let snapshot = buffer.update(&mut cx, |buffer, _| buffer.snapshot())?;
+ let outline = snapshot.outline(None).context("no outline for buffer")?;
+
+ let suggestion = match kind {
+ WorkflowSuggestionToolKind::Update {
+ symbol,
+ description,
+ } => {
+ let (symbol_path, symbol) = outline
+ .find_most_similar(&symbol)
+ .with_context(|| format!("symbol not found: {:?}", symbol))?;
+ let symbol = symbol.to_point(&snapshot);
+ let start = symbol
+ .annotation_range
+ .map_or(symbol.range.start, |range| range.start);
+ let start = Point::new(start.row, 0);
+ let end = Point::new(
+ symbol.range.end.row,
+ snapshot.line_len(symbol.range.end.row),
+ );
+ let range = snapshot.anchor_before(start)..snapshot.anchor_after(end);
+ WorkflowSuggestion::Update {
+ range,
+ description,
+ symbol_path,
+ }
+ }
+ WorkflowSuggestionToolKind::Create { description } => {
+ WorkflowSuggestion::CreateFile { description }
+ }
+ WorkflowSuggestionToolKind::InsertSiblingBefore {
+ symbol,
+ description,
+ } => {
+ let (symbol_path, symbol) = outline
+ .find_most_similar(&symbol)
+ .with_context(|| format!("symbol not found: {:?}", symbol))?;
+ let symbol = symbol.to_point(&snapshot);
+ let position = snapshot.anchor_before(
+ symbol
+ .annotation_range
+ .map_or(symbol.range.start, |annotation_range| {
+ annotation_range.start
+ }),
+ );
+ WorkflowSuggestion::InsertSiblingBefore {
+ position,
+ description,
+ symbol_path,
+ }
+ }
+ WorkflowSuggestionToolKind::InsertSiblingAfter {
+ symbol,
+ description,
+ } => {
+ let (symbol_path, symbol) = outline
+ .find_most_similar(&symbol)
+ .with_context(|| format!("symbol not found: {:?}", symbol))?;
+ let symbol = symbol.to_point(&snapshot);
+ let position = snapshot.anchor_after(symbol.range.end);
+ WorkflowSuggestion::InsertSiblingAfter {
+ position,
+ description,
+ symbol_path,
+ }
+ }
+ WorkflowSuggestionToolKind::PrependChild {
+ symbol,
+ description,
+ } => {
+ if let Some(symbol) = symbol {
+ let (symbol_path, symbol) = outline
+ .find_most_similar(&symbol)
+ .with_context(|| format!("symbol not found: {:?}", symbol))?;
+ let symbol = symbol.to_point(&snapshot);
+
+ let position = snapshot.anchor_after(
+ symbol
+ .body_range
+ .map_or(symbol.range.start, |body_range| body_range.start),
+ );
+ WorkflowSuggestion::PrependChild {
+ position,
+ description,
+ symbol_path: Some(symbol_path),
+ }
+ } else {
+ WorkflowSuggestion::PrependChild {
+ position: language::Anchor::MIN,
+ description,
+ symbol_path: None,
+ }
+ }
+ }
+                WorkflowSuggestionToolKind::AppendChild {
+                    symbol,
+                    description,
+                } => {
+                    if let Some(symbol) = symbol {
+                        let (symbol_path, symbol) = outline
+                            .find_most_similar(&symbol)
+                            .with_context(|| format!("symbol not found: {:?}", symbol))?;
+                        let symbol = symbol.to_point(&snapshot);
+
+                        let position = snapshot.anchor_before(
+                            symbol
+                                .body_range
+                                .map_or(symbol.range.end, |body_range| body_range.end),
+                        );
+                        WorkflowSuggestion::AppendChild {
+                            position,
+                            description,
+                            symbol_path: Some(symbol_path),
+                        }
+                    } else { // no symbol: append at the very end of the file (was PrependChild — copy-paste bug from the symmetric arm)
+                        WorkflowSuggestion::AppendChild {
+                            position: language::Anchor::MAX,
+                            description,
+                            symbol_path: None,
+                        }
+                    }
+                }
+ WorkflowSuggestionToolKind::Delete { symbol } => {
+ let (symbol_path, symbol) = outline
+ .find_most_similar(&symbol)
+ .with_context(|| format!("symbol not found: {:?}", symbol))?;
+ let symbol = symbol.to_point(&snapshot);
+ let start = symbol
+ .annotation_range
+ .map_or(symbol.range.start, |range| range.start);
+ let start = Point::new(start.row, 0);
+ let end = Point::new(
+ symbol.range.end.row,
+ snapshot.line_len(symbol.range.end.row),
+ );
+ let range = snapshot.anchor_before(start)..snapshot.anchor_after(end);
+ WorkflowSuggestion::Delete { range, symbol_path }
+ }
+ };
+
+ Ok((buffer, suggestion))
+ }
+ }
+
+ #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
+ #[serde(tag = "kind")]
+ pub enum WorkflowSuggestionToolKind {
+ /// Rewrites the specified symbol entirely based on the given description.
+ /// This operation completely replaces the existing symbol with new content.
+ Update {
+ /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ /// The path should uniquely identify the symbol within the containing file.
+ symbol: String,
+ /// A brief description of the transformation to apply to the symbol.
+ description: String,
+ },
+ /// Creates a new file with the given path based on the provided description.
+ /// This operation adds a new file to the codebase.
+ Create {
+ /// A brief description of the file to be created.
+ description: String,
+ },
+ /// Inserts a new symbol based on the given description before the specified symbol.
+ /// This operation adds new content immediately preceding an existing symbol.
+ InsertSiblingBefore {
+ /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ /// The new content will be inserted immediately before this symbol.
+ symbol: String,
+ /// A brief description of the new symbol to be inserted.
+ description: String,
+ },
+ /// Inserts a new symbol based on the given description after the specified symbol.
+ /// This operation adds new content immediately following an existing symbol.
+ InsertSiblingAfter {
+ /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ /// The new content will be inserted immediately after this symbol.
+ symbol: String,
+ /// A brief description of the new symbol to be inserted.
+ description: String,
+ },
+ /// Inserts a new symbol as a child of the specified symbol at the start.
+ /// This operation adds new content as the first child of an existing symbol (or file if no symbol is provided).
+ PrependChild {
+ /// An optional fully-qualified reference to the symbol after the code you want to insert, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ /// If provided, the new content will be inserted as the first child of this symbol.
+ /// If not provided, the new content will be inserted at the top of the file.
+ symbol: Option<String>,
+ /// A brief description of the new symbol to be inserted.
+ description: String,
+ },
+ /// Inserts a new symbol as a child of the specified symbol at the end.
+ /// This operation adds new content as the last child of an existing symbol (or file if no symbol is provided).
+ AppendChild {
+ /// An optional fully-qualified reference to the symbol before the code you want to insert, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ /// If provided, the new content will be inserted as the last child of this symbol.
+ /// If not provided, the new content will be applied at the bottom of the file.
+ symbol: Option<String>,
+ /// A brief description of the new symbol to be inserted.
+ description: String,
+ },
+ /// Deletes the specified symbol from the containing file.
+ Delete {
+        /// A fully-qualified reference to the symbol to be deleted, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
+ symbol: String,
+ },
+ }
+}
@@ -0,0 +1,315 @@
+use super::WorkflowStep;
+use crate::{Assist, Context};
+use editor::{
+ display_map::{BlockDisposition, BlockProperties, BlockStyle},
+ Editor, EditorEvent, ExcerptRange, MultiBuffer,
+};
+use gpui::{
+ div, AnyElement, AppContext, Context as _, Empty, EventEmitter, FocusableView, IntoElement,
+ Model, ParentElement as _, Render, SharedString, Styled as _, View, ViewContext,
+ VisualContext as _, WeakModel, WindowContext,
+};
+use language::{language_settings::SoftWrap, Anchor, Buffer, LanguageRegistry};
+use std::{ops::DerefMut, sync::Arc};
+use text::OffsetRangeExt;
+use theme::ActiveTheme as _;
+use ui::{
+ h_flex, v_flex, ButtonCommon as _, ButtonLike, ButtonStyle, Color, Icon, IconName,
+ InteractiveElement as _, Label, LabelCommon as _,
+};
+use workspace::{
+ item::{self, Item},
+ pane,
+ searchable::SearchableItemHandle,
+};
+
+pub struct WorkflowStepView {
+    step: WeakModel<WorkflowStep>, // weak handle; the view closes itself when the step model is released
+    tool_output_buffer: Model<Buffer>, // mirrors the step's streaming tool output (JSON-highlighted)
+    editor: View<Editor>, // read-only multibuffer showing the step input and the tool output
+}
+
+impl WorkflowStepView {
+ pub fn new(
+ context: Model<Context>,
+ step: Model<WorkflowStep>,
+ language_registry: Arc<LanguageRegistry>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ let tool_output_buffer =
+ cx.new_model(|cx| Buffer::local(step.read(cx).tool_output.clone(), cx));
+ let buffer = cx.new_model(|cx| {
+ let mut buffer = MultiBuffer::without_headers(0, language::Capability::ReadWrite);
+ buffer.push_excerpts(
+ context.read(cx).buffer().clone(),
+ [ExcerptRange {
+ context: step.read(cx).context_buffer_range.clone(),
+ primary: None,
+ }],
+ cx,
+ );
+ buffer.push_excerpts(
+ tool_output_buffer.clone(),
+ [ExcerptRange {
+ context: Anchor::MIN..Anchor::MAX,
+ primary: None,
+ }],
+ cx,
+ );
+ buffer
+ });
+
+        let buffer_snapshot = buffer.read(cx).snapshot(cx);
+        let output_excerpt = buffer_snapshot.excerpts().nth(1).unwrap().0; // tool-output excerpt pushed second above; nth(1) replaces skip(1).next()
+        let input_start_anchor = multi_buffer::Anchor::min();
+        let output_start_anchor = buffer_snapshot
+            .anchor_in_excerpt(output_excerpt, Anchor::MIN)
+            .unwrap();
+        let output_end_anchor = multi_buffer::Anchor::max();
+
+ let handle = cx.view().downgrade();
+ let editor = cx.new_view(|cx| {
+ let mut editor = Editor::for_multibuffer(buffer.clone(), None, false, cx);
+ editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
+ editor.set_show_line_numbers(false, cx);
+ editor.set_show_git_diff_gutter(false, cx);
+ editor.set_show_code_actions(false, cx);
+ editor.set_show_runnables(false, cx);
+ editor.set_show_wrap_guides(false, cx);
+ editor.set_show_indent_guides(false, cx);
+ editor.set_read_only(true);
+ editor.set_show_inline_completions(false);
+ editor.insert_blocks(
+ [
+ BlockProperties {
+ position: input_start_anchor,
+ height: 1,
+ style: BlockStyle::Fixed,
+ render: Box::new(|cx| section_header("Step Input", cx)),
+ disposition: BlockDisposition::Above,
+ priority: 0,
+ },
+ BlockProperties {
+ position: output_start_anchor,
+ height: 1,
+ style: BlockStyle::Fixed,
+ render: Box::new(|cx| section_header("Tool Output", cx)),
+ disposition: BlockDisposition::Above,
+ priority: 0,
+ },
+ BlockProperties {
+ position: output_end_anchor,
+ height: 1,
+ style: BlockStyle::Fixed,
+ render: Box::new(move |cx| {
+ if let Some(result) = handle.upgrade().and_then(|this| {
+ this.update(cx.deref_mut(), |this, cx| this.render_result(cx))
+ }) {
+ v_flex()
+ .child(section_header("Output", cx))
+ .child(
+ div().pl(cx.gutter_dimensions.full_width()).child(result),
+ )
+ .into_any_element()
+ } else {
+ Empty.into_any_element()
+ }
+ }),
+ disposition: BlockDisposition::Below,
+ priority: 0,
+ },
+ ],
+ None,
+ cx,
+ );
+ editor
+ });
+
+ cx.observe(&step, Self::step_updated).detach();
+ cx.observe_release(&step, Self::step_released).detach();
+
+ cx.spawn(|this, mut cx| async move {
+ if let Ok(language) = language_registry.language_for_name("JSON").await {
+ this.update(&mut cx, |this, cx| {
+ this.tool_output_buffer.update(cx, |buffer, cx| {
+ buffer.set_language(Some(language), cx);
+ });
+ })
+ .ok();
+ }
+ })
+ .detach();
+
+ Self {
+ tool_output_buffer,
+ step: step.downgrade(),
+ editor,
+ }
+ }
+
+ pub fn step(&self) -> &WeakModel<WorkflowStep> {
+ &self.step
+ }
+
+ fn render_result(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
+ let step = self.step.upgrade()?;
+ let result = step.read(cx).resolution.as_ref()?;
+ match result {
+ Ok(result) => {
+ Some(
+ v_flex()
+ .child(result.title.clone())
+ .children(result.suggestion_groups.iter().filter_map(
+ |(buffer, suggestion_groups)| {
+ let buffer = buffer.read(cx);
+ let path = buffer.file().map(|f| f.path());
+ let snapshot = buffer.snapshot();
+ v_flex()
+ .mb_2()
+ .border_b_1()
+ .children(path.map(|path| format!("path: {}", path.display())))
+ .children(suggestion_groups.iter().map(|group| {
+ v_flex().pt_2().pl_2().children(
+ group.suggestions.iter().map(|suggestion| {
+ let range = suggestion.range().to_point(&snapshot);
+ v_flex()
+ .children(
+ suggestion.description().map(|desc| {
+ format!("description: {desc}")
+ }),
+ )
+ .child(format!("kind: {}", suggestion.kind()))
+ .children(suggestion.symbol_path().map(
+ |path| format!("symbol path: {}", path.0),
+ ))
+ .child(format!(
+ "lines: {} - {}",
+ range.start.row + 1,
+ range.end.row + 1
+ ))
+ }),
+ )
+ }))
+ .into()
+ },
+ ))
+ .into_any_element(),
+ )
+ }
+ Err(error) => Some(format!("{:?}", error).into_any_element()),
+ }
+ }
+
+    fn step_updated(&mut self, step: Model<WorkflowStep>, cx: &mut ViewContext<Self>) { // mirror the step's latest tool output into the view's buffer
+        self.tool_output_buffer.update(cx, |buffer, cx| {
+            let text = step.read(cx).tool_output.clone();
+            buffer.set_text(text, cx);
+        });
+        cx.notify();
+    }
+
+    fn step_released(&mut self, _: &mut WorkflowStep, cx: &mut ViewContext<Self>) {
+        cx.emit(EditorEvent::Closed); // close this view once its step model is dropped
+    }
+
+    fn resolve(&mut self, _: &Assist, cx: &mut ViewContext<Self>) { // `Assist` action handler: re-run resolution for this step
+        self.step
+            .update(cx, |step, cx| {
+                step.resolve(cx);
+            })
+            .ok();
+    }
+}
+
+fn section_header( // full-width header row rendered as a block between editor excerpts
+    name: &'static str,
+    cx: &mut editor::display_map::BlockContext,
+) -> gpui::AnyElement {
+    h_flex()
+        .pl(cx.gutter_dimensions.full_width()) // indent past the editor gutter so the label aligns with text
+        .h_11()
+        .w_full()
+        .relative()
+        .gap_1()
+        .child(
+            ButtonLike::new("role")
+                .style(ButtonStyle::Filled)
+                .child(Label::new(name).color(Color::Default)),
+        )
+        .into_any_element()
+}
+
+impl Render for WorkflowStepView {
+    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        div()
+            .key_context("ContextEditor") // share the context editor's key context so the `Assist` binding reaches `resolve`
+            .on_action(cx.listener(Self::resolve))
+            .flex_grow()
+            .bg(cx.theme().colors().editor_background)
+            .child(self.editor.clone())
+    }
+}
+
+impl EventEmitter<EditorEvent> for WorkflowStepView {}
+
+impl FocusableView for WorkflowStepView {
+ fn focus_handle(&self, cx: &gpui::AppContext) -> gpui::FocusHandle {
+ self.editor.read(cx).focus_handle(cx)
+ }
+}
+
+impl Item for WorkflowStepView {
+ type Event = EditorEvent;
+
+ fn tab_content_text(&self, cx: &WindowContext) -> Option<SharedString> {
+ let step = self.step.upgrade()?.read(cx);
+ let context = step.context.upgrade()?.read(cx);
+ let buffer = context.buffer().read(cx);
+ let index = context
+ .workflow_step_index_for_range(&step.context_buffer_range, buffer)
+ .ok()?
+ + 1;
+ Some(format!("Step {index}").into())
+ }
+
+ fn tab_icon(&self, _cx: &WindowContext) -> Option<ui::Icon> {
+ Some(Icon::new(IconName::Pencil))
+ }
+
+ fn to_item_events(event: &Self::Event, mut f: impl FnMut(item::ItemEvent)) {
+ match event {
+ EditorEvent::Edited { .. } => {
+ f(item::ItemEvent::Edit);
+ }
+ EditorEvent::TitleChanged => {
+ f(item::ItemEvent::UpdateTab);
+ }
+ EditorEvent::Closed => f(item::ItemEvent::CloseItem),
+ _ => {}
+ }
+ }
+
+ fn tab_tooltip_text(&self, _cx: &AppContext) -> Option<SharedString> {
+ None
+ }
+
+ fn as_searchable(&self, _handle: &View<Self>) -> Option<Box<dyn SearchableItemHandle>> {
+ None
+ }
+
+ fn set_nav_history(&mut self, nav_history: pane::ItemNavHistory, cx: &mut ViewContext<Self>) {
+ self.editor.update(cx, |editor, cx| {
+ Item::set_nav_history(editor, nav_history, cx)
+ })
+ }
+
+ fn navigate(&mut self, data: Box<dyn std::any::Any>, cx: &mut ViewContext<Self>) -> bool {
+ self.editor
+ .update(cx, |editor, cx| Item::navigate(editor, data, cx))
+ }
+
+ fn deactivated(&mut self, cx: &mut ViewContext<Self>) {
+ self.editor
+ .update(cx, |editor, cx| Item::deactivated(editor, cx))
+ }
+}
@@ -15,6 +15,35 @@ pub fn init(cx: &mut AppContext) {
SlashCommandRegistry::default_global(cx);
}
+/// What should happen after a slash-command argument completion is accepted.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum AfterCompletion {
+    /// Run the command
+    Run,
+    /// Continue composing the current argument, doesn't add a space
+    Compose,
+    /// Continue the command composition, adds a space
+    Continue,
+}
+
+impl From<bool> for AfterCompletion {
+    // Back-compat shim for the old boolean `run_command` field:
+    // `true` meant "run now", `false` meant "keep composing" (with a space).
+    fn from(value: bool) -> Self {
+        if value {
+            AfterCompletion::Run
+        } else {
+            AfterCompletion::Continue
+        }
+    }
+}
+
+impl AfterCompletion {
+    /// Returns `true` only if accepting the completion should run the command.
+    pub fn run(&self) -> bool {
+        match self {
+            AfterCompletion::Run => true,
+            AfterCompletion::Compose | AfterCompletion::Continue => false,
+        }
+    }
+}
+
#[derive(Debug)]
pub struct ArgumentCompletion {
/// The label to display for this completion.
@@ -22,7 +51,7 @@ pub struct ArgumentCompletion {
/// The new text that should be inserted into the command when this completion is accepted.
pub new_text: String,
/// Whether the command should be run when accepting this completion.
- pub run_command: bool,
+ pub after_completion: AfterCompletion,
/// Whether to replace the all arguments, or whether to treat this as an independent argument.
pub replace_previous_arguments: bool,
}
@@ -42,6 +71,9 @@ pub trait SlashCommand: 'static + Send + Sync {
cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>>;
fn requires_argument(&self) -> bool;
+ fn accepts_arguments(&self) -> bool {
+ self.requires_argument()
+ }
fn run(
self: Arc<Self>,
arguments: &[String],
@@ -58,10 +58,14 @@ impl SlashCommandRegistry {
/// Unregisters the provided [`SlashCommand`].
pub fn unregister_command(&self, command: impl SlashCommand) {
+ self.unregister_command_by_name(command.name().as_str())
+ }
+
+ /// Unregisters the command with the given name.
+ pub fn unregister_command_by_name(&self, command_name: &str) {
let mut state = self.state.write();
- let command_name: Arc<str> = command.name().into();
- state.featured_commands.remove(&command_name);
- state.commands.remove(&command_name);
+ state.featured_commands.remove(command_name);
+ state.commands.remove(command_name);
}
/// Returns the names of registered [`SlashCommand`]s.
@@ -32,6 +32,7 @@ fn add_vars(env_content: String, vars: &mut Vec<(String, String)>) -> Result<()>
toml::Value::String(value) => value,
toml::Value::Integer(value) => value.to_string(),
toml::Value::Float(value) => value.to_string(),
+ toml::Value::Boolean(value) => value.to_string(),
_ => panic!("unsupported TOML value in .env.toml for key {}", key),
};
vars.push((key, value));
@@ -138,7 +138,10 @@ async fn validate_api_token<B>(mut req: Request<B>, next: Next<B>) -> impl IntoR
));
}
- tracing::Span::current().record("authn.jti", &claims.jti);
+ tracing::Span::current()
+ .record("user_id", claims.user_id)
+ .record("login", claims.github_user_login.clone())
+ .record("authn.jti", &claims.jti);
req.extensions_mut().insert(claims);
Ok::<_, Error>(next.run(req).await.into_response())
@@ -26,7 +26,7 @@ fn authorize_access_to_model(
}
match (provider, model) {
- (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
+ (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
Ok(())
}
_ => Err(Error::http(
@@ -240,14 +240,14 @@ mod tests {
(
Plan::ZedPro,
LanguageModelProvider::Anthropic,
- "claude-3.5-sonnet",
+ "claude-3-5-sonnet",
true,
),
// Free plan should have access to claude-3.5-sonnet
(
Plan::Free,
LanguageModelProvider::Anthropic,
- "claude-3.5-sonnet",
+ "claude-3-5-sonnet",
true,
),
// Pro plan should NOT have access to other Anthropic models
@@ -303,7 +303,7 @@ mod tests {
// Staff should have access to all models
let test_cases = vec![
- (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
+ (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
(LanguageModelProvider::Anthropic, "claude-2"),
(LanguageModelProvider::Anthropic, "claude-123-agi"),
(LanguageModelProvider::OpenAi, "gpt-4"),
@@ -13,6 +13,12 @@ pub struct LlmTokenClaims {
pub exp: u64,
pub jti: String,
pub user_id: u64,
+ // This field is temporarily optional so it can be added
+ // in a backwards-compatible way. We can make it required
+ // once all of the LLM tokens have cycled (~1 hour after
+ // this change has been deployed).
+ #[serde(default)]
+ pub github_user_login: Option<String>,
pub is_staff: bool,
pub plan: rpc::proto::Plan,
}
@@ -22,6 +28,7 @@ const LLM_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
impl LlmTokenClaims {
pub fn create(
user_id: UserId,
+ github_user_login: String,
is_staff: bool,
plan: rpc::proto::Plan,
config: &Config,
@@ -37,6 +44,7 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user_id.to_proto(),
+ github_user_login: Some(github_user_login),
is_staff,
plan,
};
@@ -150,6 +150,8 @@ async fn main() -> Result<()> {
"http_request",
method = ?request.method(),
matched_path,
+ user_id = tracing::field::Empty,
+ login = tracing::field::Empty,
authn.jti = tracing::field::Empty
)
})
@@ -71,7 +71,7 @@ use std::{
time::{Duration, Instant},
};
use time::OffsetDateTime;
-use tokio::sync::{watch, Semaphore};
+use tokio::sync::{watch, MutexGuard, Semaphore};
use tower::ServiceBuilder;
use tracing::{
field::{self},
@@ -192,7 +192,7 @@ impl Session {
}
}
- pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
+ pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
if self.is_staff() {
return Ok(proto::Plan::ZedPro);
}
@@ -201,7 +201,6 @@ impl Session {
return Ok(proto::Plan::Free);
};
- let db = self.db().await;
if db.has_active_billing_subscription(user_id).await? {
Ok(proto::Plan::ZedPro)
} else {
@@ -3500,7 +3499,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
}
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
- let plan = session.current_plan().await?;
+ let plan = session.current_plan(session.db().await).await?;
session
.peer
@@ -4503,7 +4502,7 @@ async fn count_language_model_tokens(
};
authorize_access_to_legacy_llm_endpoints(&session).await?;
- let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
};
@@ -4623,7 +4622,7 @@ async fn compute_embeddings(
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_legacy_llm_endpoints(&session).await?;
- let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
};
@@ -4940,11 +4939,11 @@ async fn get_llm_api_token(
if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
Err(anyhow!("account too young"))?
}
-
let token = LlmTokenClaims::create(
user.id,
+ user.github_login.clone(),
session.is_staff(),
- session.current_plan().await?,
+ session.current_plan(db).await?,
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
@@ -314,7 +314,6 @@ impl MessageEditor {
server_id: LanguageServerId(0), // TODO: Make this optional or something?
lsp_completion: Default::default(), // TODO: Make this optional or something?
confirm: None,
- show_new_completions_on_confirm: false,
}
})
.collect()
@@ -0,0 +1,29 @@
+[package]
+name = "context_servers"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/context_servers.rs"
+
+[dependencies]
+anyhow.workspace = true
+collections.workspace = true
+futures.workspace = true
+gpui.workspace = true
+log.workspace = true
+parking_lot.workspace = true
+postage.workspace = true
+schemars.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+url = { workspace = true, features = ["serde"] }
+util.workspace = true
+workspace.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,432 @@
+use anyhow::{anyhow, Context, Result};
+use collections::HashMap;
+use futures::{channel::oneshot, io::BufWriter, select, AsyncRead, AsyncWrite, FutureExt};
+use gpui::{AsyncAppContext, BackgroundExecutor, Task};
+use parking_lot::Mutex;
+use postage::barrier;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use serde_json::{value::RawValue, Value};
+use smol::{
+ channel,
+ io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
+ process::{self, Child},
+};
+use std::{
+ fmt,
+ path::PathBuf,
+ sync::{
+ atomic::{AtomicI32, Ordering::SeqCst},
+ Arc,
+ },
+ time::{Duration, Instant},
+};
+use util::TryFutureExt;
+
+/// JSON-RPC protocol version stamped on every outgoing message.
+const JSON_RPC_VERSION: &str = "2.0";
+/// How long `Client::request` waits for a response before giving up.
+const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);
+
+// Callback invoked with the raw response text (or error) for a pending request.
+type ResponseHandler = Box<dyn Send + FnOnce(Result<String, Error>)>;
+// Callback invoked for server-initiated notifications, keyed by method name.
+type NotificationHandler = Box<dyn Send + FnMut(RequestId, Value, AsyncAppContext)>;
+
+/// JSON-RPC message id; the spec allows either an integer or a string.
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum RequestId {
+    Int(i32),
+    Str(String),
+}
+
+/// JSON-RPC client talking to a context server over the child process's stdio.
+pub struct Client {
+    server_id: ContextServerId,
+    next_id: AtomicI32,
+    outbound_tx: channel::Sender<String>,
+    name: Arc<str>,
+    notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+    // Becomes `None` once the output loop shuts down, so later requests
+    // fail fast instead of hanging.
+    response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+    #[allow(clippy::type_complexity)]
+    #[allow(dead_code)]
+    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
+    #[allow(dead_code)]
+    output_done_rx: Mutex<Option<barrier::Receiver>>,
+    executor: BackgroundExecutor,
+    server: Arc<Mutex<Option<Child>>>,
+}
+
+/// Identifier for a configured context server (its settings `id`).
+#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
+#[repr(transparent)]
+pub struct ContextServerId(pub String);
+
+/// Outgoing JSON-RPC request envelope.
+#[derive(Serialize, Deserialize)]
+struct Request<'a, T> {
+    jsonrpc: &'static str,
+    id: RequestId,
+    method: &'a str,
+    params: T,
+}
+
+/// Incoming response with `result` left unparsed for the caller to decode.
+#[derive(Serialize, Deserialize)]
+struct AnyResponse<'a> {
+    jsonrpc: &'a str,
+    id: RequestId,
+    #[serde(default)]
+    error: Option<Error>,
+    #[serde(borrow)]
+    result: Option<&'a RawValue>,
+}
+
+#[derive(Deserialize)]
+#[allow(dead_code)]
+struct Response<T> {
+    jsonrpc: &'static str,
+    id: RequestId,
+    #[serde(flatten)]
+    value: CspResult<T>,
+}
+
+/// Either the `result` or `error` member of a JSON-RPC response.
+#[derive(Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum CspResult<T> {
+    #[serde(rename = "result")]
+    Ok(Option<T>),
+    #[allow(dead_code)]
+    Error(Option<Error>),
+}
+
+// NOTE(review): this notification envelope carries an `id`, but JSON-RPC 2.0
+// defines notifications as requests *without* an id — confirm the peer
+// accepts this shape.
+#[derive(Serialize, Deserialize)]
+struct Notification<'a, T> {
+    jsonrpc: &'static str,
+    id: RequestId,
+    #[serde(borrow)]
+    method: &'a str,
+    params: T,
+}
+
+/// Incoming notification with params left as a generic JSON value.
+#[derive(Debug, Clone, Deserialize)]
+struct AnyNotification<'a> {
+    jsonrpc: &'a str,
+    id: RequestId,
+    method: String,
+    #[serde(default)]
+    params: Option<Value>,
+}
+
+/// JSON-RPC error object (only the message is retained).
+#[derive(Debug, Serialize, Deserialize)]
+struct Error {
+    message: String,
+}
+
+/// How to launch a context server: executable path, arguments, and optional
+/// extra environment variables.
+#[derive(Debug, Clone, Deserialize)]
+pub struct ModelContextServerBinary {
+    pub executable: PathBuf,
+    pub args: Vec<String>,
+    pub env: Option<HashMap<String, String>>,
+}
+
+impl Client {
+    /// Creates a new Client instance for a context server.
+    ///
+    /// This function initializes a new Client by spawning a child process for the context server,
+    /// setting up communication channels, and initializing handlers for input/output operations.
+    /// It takes a server ID, binary information, and an async app context as input.
+    pub fn new(
+        server_id: ContextServerId,
+        binary: ModelContextServerBinary,
+        cx: AsyncAppContext,
+    ) -> Result<Self> {
+        log::info!(
+            "starting context server (executable={:?}, args={:?})",
+            binary.executable,
+            &binary.args
+        );
+
+        // Spawn the server with piped stdio; `kill_on_drop` keeps the child
+        // from outliving this client if the process handle is dropped.
+        let mut command = process::Command::new(&binary.executable);
+        command
+            .args(&binary.args)
+            .envs(binary.env.unwrap_or_default())
+            .stdin(std::process::Stdio::piped())
+            .stdout(std::process::Stdio::piped())
+            .stderr(std::process::Stdio::piped())
+            .kill_on_drop(true);
+
+        let mut server = command.spawn().with_context(|| {
+            format!(
+                "failed to spawn command. (path={:?}, args={:?})",
+                binary.executable, &binary.args
+            )
+        })?;
+
+        let stdin = server.stdin.take().unwrap();
+        let stdout = server.stdout.take().unwrap();
+        let stderr = server.stderr.take().unwrap();
+
+        let (outbound_tx, outbound_rx) = channel::unbounded::<String>();
+        let (output_done_tx, output_done_rx) = barrier::channel();
+
+        let notification_handlers =
+            Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
+        let response_handlers =
+            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
+
+        // Reader tasks: one for stdout (JSON-RPC traffic), one for stderr (logs).
+        let stdout_input_task = cx.spawn({
+            let notification_handlers = notification_handlers.clone();
+            let response_handlers = response_handlers.clone();
+            move |cx| {
+                Self::handle_input(stdout, notification_handlers, response_handlers, cx).log_err()
+            }
+        });
+        let stderr_input_task = cx.spawn(|_| Self::handle_stderr(stderr).log_err());
+        let input_task = cx.spawn(|_| async move {
+            let (stdout, stderr) = futures::join!(stdout_input_task, stderr_input_task);
+            stdout.or(stderr)
+        });
+        // Writer task: drains `outbound_rx` into the child's stdin.
+        let output_task = cx.background_executor().spawn({
+            Self::handle_output(
+                stdin,
+                outbound_rx,
+                output_done_tx,
+                response_handlers.clone(),
+            )
+            .log_err()
+        });
+
+        let mut context_server = Self {
+            server_id,
+            notification_handlers,
+            response_handlers,
+            name: "".into(),
+            next_id: Default::default(),
+            outbound_tx,
+            executor: cx.background_executor().clone(),
+            io_tasks: Mutex::new(Some((input_task, output_task))),
+            output_done_rx: Mutex::new(Some(output_done_rx)),
+            server: Arc::new(Mutex::new(Some(server))),
+        };
+
+        // Default the display name to the executable's file name.
+        if let Some(name) = binary.executable.file_name() {
+            context_server.name = name.to_string_lossy().into();
+        }
+
+        Ok(context_server)
+    }
+
+    /// Handles input from the server's stdout.
+    ///
+    /// This function continuously reads lines from the provided stdout stream,
+    /// parses them as JSON-RPC responses or notifications, and dispatches them
+    /// to the appropriate handlers. It processes both responses (which are matched
+    /// to pending requests) and notifications (which trigger registered handlers).
+    async fn handle_input<Stdout>(
+        stdout: Stdout,
+        notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+        cx: AsyncAppContext,
+    ) -> anyhow::Result<()>
+    where
+        Stdout: AsyncRead + Unpin + Send + 'static,
+    {
+        let mut stdout = BufReader::new(stdout);
+        let mut buffer = String::new();
+
+        loop {
+            buffer.clear();
+            // EOF: the server closed its stdout; end the read loop cleanly.
+            if stdout.read_line(&mut buffer).await? == 0 {
+                return Ok(());
+            }
+
+            let content = buffer.trim();
+
+            if !content.is_empty() {
+                // Responses are tried first. NOTE(review): lines that parse as
+                // neither envelope are silently dropped — confirm that is the
+                // intended behavior for malformed server output.
+                if let Ok(response) = serde_json::from_str::<AnyResponse>(&content) {
+                    if let Some(handlers) = response_handlers.lock().as_mut() {
+                        if let Some(handler) = handlers.remove(&response.id) {
+                            handler(Ok(content.to_string()));
+                        }
+                    }
+                } else if let Ok(notification) = serde_json::from_str::<AnyNotification>(&content) {
+                    let mut notification_handlers = notification_handlers.lock();
+                    if let Some(handler) =
+                        notification_handlers.get_mut(notification.method.as_str())
+                    {
+                        handler(
+                            notification.id,
+                            notification.params.unwrap_or(Value::Null),
+                            cx.clone(),
+                        );
+                    }
+                }
+            }
+
+            smol::future::yield_now().await;
+        }
+    }
+
+    /// Handles the stderr output from the context server.
+    /// Continuously reads and logs any error messages from the server.
+    async fn handle_stderr<Stderr>(stderr: Stderr) -> anyhow::Result<()>
+    where
+        Stderr: AsyncRead + Unpin + Send + 'static,
+    {
+        let mut stderr = BufReader::new(stderr);
+        let mut buffer = String::new();
+
+        loop {
+            buffer.clear();
+            if stderr.read_line(&mut buffer).await? == 0 {
+                return Ok(());
+            }
+            log::warn!("context server stderr: {}", buffer.trim());
+            smol::future::yield_now().await;
+        }
+    }
+
+    /// Handles the output to the context server's stdin.
+    /// This function continuously receives messages from the outbound channel,
+    /// writes them to the server's stdin, and manages the lifecycle of response handlers.
+    async fn handle_output<Stdin>(
+        stdin: Stdin,
+        outbound_rx: channel::Receiver<String>,
+        output_done_tx: barrier::Sender,
+        response_handlers: Arc<Mutex<Option<HashMap<RequestId, ResponseHandler>>>>,
+    ) -> anyhow::Result<()>
+    where
+        Stdin: AsyncWrite + Unpin + Send + 'static,
+    {
+        let mut stdin = BufWriter::new(stdin);
+        // On exit (including errors), take the handler map so subsequent
+        // `request` calls fail fast with "server shut down".
+        let _clear_response_handlers = util::defer({
+            let response_handlers = response_handlers.clone();
+            move || {
+                response_handlers.lock().take();
+            }
+        });
+        while let Ok(message) = outbound_rx.recv().await {
+            log::trace!("outgoing message: {}", message);
+
+            // Newline-delimited JSON framing.
+            stdin.write_all(message.as_bytes()).await?;
+            stdin.write_all(b"\n").await?;
+            stdin.flush().await?;
+        }
+        // Dropping the barrier sender signals that output has fully drained.
+        drop(output_done_tx);
+        Ok(())
+    }
+
+    /// Sends a JSON-RPC request to the context server and waits for a response.
+    /// This function handles serialization, deserialization, timeout, and error handling.
+    pub async fn request<T: DeserializeOwned>(
+        &self,
+        method: &str,
+        params: impl Serialize,
+    ) -> Result<T> {
+        let id = self.next_id.fetch_add(1, SeqCst);
+        let request = serde_json::to_string(&Request {
+            jsonrpc: JSON_RPC_VERSION,
+            id: RequestId::Int(id),
+            method,
+            params,
+        })
+        .unwrap();
+
+        // Register the response handler before sending, so a fast reply
+        // cannot race past us.
+        let (tx, rx) = oneshot::channel();
+        let handle_response = self
+            .response_handlers
+            .lock()
+            .as_mut()
+            .ok_or_else(|| anyhow!("server shut down"))
+            .map(|handlers| {
+                handlers.insert(
+                    RequestId::Int(id),
+                    Box::new(move |result| {
+                        let _ = tx.send(result);
+                    }),
+                );
+            });
+
+        let send = self
+            .outbound_tx
+            .try_send(request)
+            .context("failed to write to context server's stdin");
+
+        let executor = self.executor.clone();
+        let started = Instant::now();
+        handle_response?;
+        send?;
+
+        // Race the response against the request timeout.
+        let mut timeout = executor.timer(REQUEST_TIMEOUT).fuse();
+        select! {
+            response = rx.fuse() => {
+                let elapsed = started.elapsed();
+                log::trace!("took {elapsed:?} to receive response to {method:?} id {id}");
+                match response? {
+                    Ok(response) => {
+                        let parsed: AnyResponse = serde_json::from_str(&response)?;
+                        if let Some(error) = parsed.error {
+                            Err(anyhow!(error.message))
+                        } else if let Some(result) = parsed.result {
+                            Ok(serde_json::from_str(result.get())?)
+                        } else {
+                            Err(anyhow!("Invalid response: no result or error"))
+                        }
+                    }
+                    Err(_) => anyhow::bail!("cancelled")
+                }
+            }
+            _ = timeout => {
+                log::error!("cancelled csp request task for {method:?} id {id} which took over {:?}", REQUEST_TIMEOUT);
+                anyhow::bail!("Context server request timeout");
+            }
+        }
+    }
+
+    /// Sends a notification to the context server without expecting a response.
+    /// This function serializes the notification and sends it through the outbound channel.
+    // NOTE(review): an `id` is allocated and attached here, but JSON-RPC 2.0
+    // notifications omit `id` — confirm the server tolerates this.
+    pub fn notify(&self, method: &str, params: impl Serialize) -> Result<()> {
+        let id = self.next_id.fetch_add(1, SeqCst);
+        let notification = serde_json::to_string(&Notification {
+            jsonrpc: JSON_RPC_VERSION,
+            id: RequestId::Int(id),
+            method,
+            params,
+        })
+        .unwrap();
+        self.outbound_tx.try_send(notification)?;
+        Ok(())
+    }
+
+    /// Registers a handler for a server-initiated notification method.
+    /// A later registration for the same method replaces the earlier one.
+    pub fn on_notification<F>(&self, method: &'static str, mut f: F)
+    where
+        F: 'static + Send + FnMut(Value, AsyncAppContext),
+    {
+        self.notification_handlers
+            .lock()
+            .insert(method, Box::new(move |_, params, cx| f(params, cx)));
+    }
+
+    /// Display name (the server executable's file name, or empty).
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    pub fn server_id(&self) -> ContextServerId {
+        self.server_id.clone()
+    }
+}
+
+impl Drop for Client {
+    fn drop(&mut self) {
+        // Best-effort kill of the child process; the error is ignored because
+        // the process may already have exited.
+        if let Some(mut server) = self.server.lock().take() {
+            let _ = server.kill();
+        }
+    }
+}
+
+impl fmt::Display for ContextServerId {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+impl fmt::Debug for Client {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("Context Server Client")
+            .field("id", &self.server_id.0)
+            .field("name", &self.name)
+            .finish_non_exhaustive()
+    }
+}
@@ -0,0 +1,36 @@
+use gpui::{actions, AppContext, Context, ViewContext};
+use log;
+use manager::ContextServerManager;
+use workspace::Workspace;
+
+pub mod client;
+pub mod manager;
+pub mod protocol;
+mod registry;
+pub mod types;
+
+pub use registry::*;
+
+actions!(context_servers, [Restart]);
+
+/// Initializes the context server subsystem: settings, the global manager,
+/// the command registry, and the workspace `Restart` action.
+pub fn init(cx: &mut AppContext) {
+    log::info!("initializing context server client");
+    manager::init(cx);
+    ContextServerRegistry::register(cx);
+
+    cx.observe_new_views(
+        |workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
+            workspace.register_action(restart_servers);
+        },
+    )
+    .detach();
+}
+
+/// Handler for the `Restart` action: restarts every running context server.
+fn restart_servers(_workspace: &mut Workspace, _action: &Restart, cx: &mut ViewContext<Workspace>) {
+    let model = ContextServerManager::global(&cx);
+    cx.update_model(&model, |manager, cx| {
+        for server in manager.servers() {
+            manager.restart_server(&server.id, cx).detach();
+        }
+    });
+}
@@ -0,0 +1,278 @@
+//! This module implements a context server management system for Zed.
+//!
+//! It provides functionality to:
+//! - Define and load context server settings
+//! - Manage individual context servers (start, stop, restart)
+//! - Maintain a global manager for all context servers
+//!
+//! Key components:
+//! - `ContextServerSettings`: Defines the structure for server configurations
+//! - `ContextServer`: Represents an individual context server
+//! - `ContextServerManager`: Manages multiple context servers
+//! - `GlobalContextServerManager`: Provides global access to the ContextServerManager
+//!
+//! The module also includes initialization logic to set up the context server system
+//! and react to changes in settings.
+
+use collections::{HashMap, HashSet};
+use gpui::{AppContext, AsyncAppContext, Context, EventEmitter, Global, Model, ModelContext, Task};
+use log;
+use parking_lot::RwLock;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use settings::{Settings, SettingsSources, SettingsStore};
+use std::path::Path;
+use std::sync::Arc;
+
+use crate::{
+ client::{self, Client},
+ types,
+};
+
+/// User-facing settings: the list of context servers to launch.
+#[derive(Deserialize, Serialize, Default, Clone, PartialEq, Eq, JsonSchema, Debug)]
+pub struct ContextServerSettings {
+    pub servers: Vec<ServerConfig>,
+}
+
+/// Launch configuration for a single context server.
+#[derive(Deserialize, Serialize, Clone, PartialEq, Eq, JsonSchema, Debug)]
+pub struct ServerConfig {
+    pub id: String,
+    pub executable: String,
+    pub args: Vec<String>,
+}
+
+impl Settings for ContextServerSettings {
+    // Settings live under an `experimental.` key while the feature stabilizes.
+    const KEY: Option<&'static str> = Some("experimental.context_servers");
+
+    type FileContent = Self;
+
+    fn load(
+        sources: SettingsSources<Self::FileContent>,
+        _: &mut gpui::AppContext,
+    ) -> anyhow::Result<Self> {
+        sources.json_merge()
+    }
+}
+
+/// A single context server: its config plus the initialized protocol client
+/// (present only while the server is running).
+pub struct ContextServer {
+    pub id: String,
+    pub config: ServerConfig,
+    pub client: RwLock<Option<Arc<crate::protocol::InitializedContextServerProtocol>>>,
+}
+
+impl ContextServer {
+    fn new(config: ServerConfig) -> Self {
+        Self {
+            id: config.id.clone(),
+            config,
+            client: RwLock::new(None),
+        }
+    }
+
+    /// Spawns the server process and runs the MCP initialize handshake;
+    /// on success the initialized protocol is stored in `self.client`.
+    async fn start(&self, cx: &AsyncAppContext) -> anyhow::Result<()> {
+        log::info!("starting context server {}", self.config.id);
+        let client = Client::new(
+            client::ContextServerId(self.config.id.clone()),
+            client::ModelContextServerBinary {
+                executable: Path::new(&self.config.executable).to_path_buf(),
+                args: self.config.args.clone(),
+                env: None,
+            },
+            cx.clone(),
+        )?;
+
+        let protocol = crate::protocol::ModelContextProtocol::new(client);
+        let client_info = types::EntityInfo {
+            name: "Zed".to_string(),
+            version: env!("CARGO_PKG_VERSION").to_string(),
+        };
+        let initialized_protocol = protocol.initialize(client_info).await?;
+
+        log::debug!(
+            "context server {} initialized: {:?}",
+            self.config.id,
+            initialized_protocol.initialize,
+        );
+
+        *self.client.write() = Some(Arc::new(initialized_protocol));
+        Ok(())
+    }
+
+    /// Drops the protocol client; the child process is killed by `Client`'s
+    /// `Drop` impl once the last reference goes away.
+    async fn stop(&self) -> anyhow::Result<()> {
+        let mut client = self.client.write();
+        if let Some(protocol) = client.take() {
+            drop(protocol);
+        }
+        Ok(())
+    }
+}
+
+/// A Context server manager manages the starting and stopping
+/// of all servers. To obtain a server to interact with, a crate
+/// must go through the `GlobalContextServerManager` which holds
+/// a model to the ContextServerManager.
+pub struct ContextServerManager {
+    servers: HashMap<String, Arc<ContextServer>>,
+    // Ids of servers that are currently starting, to prevent double-starts.
+    pending_servers: HashSet<String>,
+}
+
+pub enum Event {
+    ServerStarted { server_id: String },
+    ServerStopped { server_id: String },
+}
+
+impl Global for ContextServerManager {}
+impl EventEmitter<Event> for ContextServerManager {}
+
+impl ContextServerManager {
+    pub fn new() -> Self {
+        Self {
+            servers: HashMap::default(),
+            pending_servers: HashSet::default(),
+        }
+    }
+    pub fn global(cx: &AppContext) -> Model<Self> {
+        cx.global::<GlobalContextServerManager>().0.clone()
+    }
+
+    /// Starts a server from `config` unless one with the same id is already
+    /// running or starting. Emits `Event::ServerStarted` on success.
+    // NOTE(review): if `server.start` fails, the id is never removed from
+    // `pending_servers`, which blocks future `add_server` calls for that id —
+    // TODO confirm and clean up on the error path.
+    pub fn add_server(
+        &mut self,
+        config: ServerConfig,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        let server_id = config.id.clone();
+        let server_id2 = config.id.clone();
+
+        if self.servers.contains_key(&server_id) || self.pending_servers.contains(&server_id) {
+            return Task::ready(Ok(()));
+        }
+
+        let task = cx.spawn(|this, mut cx| async move {
+            let server = Arc::new(ContextServer::new(config));
+            server.start(&cx).await?;
+            this.update(&mut cx, |this, cx| {
+                this.servers.insert(server_id.clone(), server);
+                this.pending_servers.remove(&server_id);
+                cx.emit(Event::ServerStarted {
+                    server_id: server_id.clone(),
+                });
+            })?;
+            Ok(())
+        });
+
+        self.pending_servers.insert(server_id2);
+        task
+    }
+
+    pub fn get_server(&self, id: &str) -> Option<Arc<ContextServer>> {
+        self.servers.get(id).cloned()
+    }
+
+    /// Stops and removes the server with the given id (no-op if absent).
+    /// Emits `Event::ServerStopped`.
+    pub fn remove_server(
+        &mut self,
+        id: &str,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        let id = id.to_string();
+        cx.spawn(|this, mut cx| async move {
+            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
+                server.stop().await?;
+            }
+            this.update(&mut cx, |this, cx| {
+                this.pending_servers.remove(&id);
+                cx.emit(Event::ServerStopped {
+                    server_id: id.clone(),
+                })
+            })?;
+            Ok(())
+        })
+    }
+
+    /// Stops and relaunches the server with the given id (no-op if it is not
+    /// currently running). Emits `ServerStopped` followed by `ServerStarted`.
+    pub fn restart_server(
+        &mut self,
+        id: &str,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<anyhow::Result<()>> {
+        let id = id.to_string();
+        cx.spawn(|this, mut cx| async move {
+            if let Some(server) = this.update(&mut cx, |this, _cx| this.servers.remove(&id))? {
+                server.stop().await?;
+                let config = server.config.clone();
+                let new_server = Arc::new(ContextServer::new(config));
+                new_server.start(&cx).await?;
+                this.update(&mut cx, |this, cx| {
+                    this.servers.insert(id.clone(), new_server);
+                    cx.emit(Event::ServerStopped {
+                        server_id: id.clone(),
+                    });
+                    cx.emit(Event::ServerStarted {
+                        server_id: id.clone(),
+                    });
+                })?;
+            }
+            Ok(())
+        })
+    }
+
+    pub fn servers(&self) -> Vec<Arc<ContextServer>> {
+        self.servers.values().cloned().collect()
+    }
+
+    pub fn model(cx: &mut AppContext) -> Model<Self> {
+        cx.new_model(|_cx| ContextServerManager::new())
+    }
+}
+
+/// Global wrapper holding the app-wide `ContextServerManager` model.
+pub struct GlobalContextServerManager(Model<ContextServerManager>);
+impl Global for GlobalContextServerManager {}
+
+impl GlobalContextServerManager {
+    fn register(cx: &mut AppContext) {
+        let model = ContextServerManager::model(cx);
+        cx.set_global(Self(model));
+    }
+}
+
+/// Registers settings and the global manager, then keeps the set of running
+/// servers in sync with the `experimental.context_servers` settings.
+pub fn init(cx: &mut AppContext) {
+    ContextServerSettings::register(cx);
+    GlobalContextServerManager::register(cx);
+    cx.observe_global::<SettingsStore>(|cx| {
+        let manager = ContextServerManager::global(cx);
+        cx.update_model(&manager, |manager, cx| {
+            let settings = ContextServerSettings::get_global(cx);
+            // Diff the running servers against the configured ones by id.
+            // NOTE(review): a server whose config changed in place (same id)
+            // is neither restarted nor updated here — confirm that's intended.
+            let current_servers: HashMap<String, ServerConfig> = manager
+                .servers()
+                .into_iter()
+                .map(|server| (server.id.clone(), server.config.clone()))
+                .collect();
+
+            let new_servers = settings
+                .servers
+                .iter()
+                .map(|config| (config.id.clone(), config.clone()))
+                .collect::<HashMap<_, _>>();
+
+            let servers_to_add = new_servers
+                .values()
+                .filter(|config| !current_servers.contains_key(&config.id))
+                .cloned()
+                .collect::<Vec<_>>();
+
+            let servers_to_remove = current_servers
+                .keys()
+                .filter(|id| !new_servers.contains_key(*id))
+                .cloned()
+                .collect::<Vec<_>>();
+
+            log::trace!("servers_to_add={:?}", servers_to_add);
+            for config in servers_to_add {
+                manager.add_server(config, cx).detach();
+            }
+
+            for id in servers_to_remove {
+                manager.remove_server(&id, cx).detach();
+            }
+        })
+    })
+    .detach();
+}
@@ -0,0 +1,140 @@
+//! This module implements parts of the Model Context Protocol.
+//!
+//! It handles the lifecycle messages, and provides a general interface to
+//! interacting with an MCP server. It uses the generic JSON-RPC client to
+//! read/write messages and the types from types.rs for serialization/deserialization
+//! of messages.
+
+use anyhow::Result;
+use collections::HashMap;
+
+use crate::client::Client;
+use crate::types;
+
+pub use types::PromptInfo;
+
+/// MCP protocol version sent in the `initialize` request.
+const PROTOCOL_VERSION: u32 = 1;
+
+/// Un-initialized protocol wrapper; consumed by `initialize`.
+pub struct ModelContextProtocol {
+    inner: Client,
+}
+
+impl ModelContextProtocol {
+    pub fn new(inner: Client) -> Self {
+        Self { inner }
+    }
+
+    /// Runs the MCP lifecycle handshake: sends `initialize`, then the
+    /// `initialized` notification, and returns the capability-aware protocol.
+    pub async fn initialize(
+        self,
+        client_info: types::EntityInfo,
+    ) -> Result<InitializedContextServerProtocol> {
+        let params = types::InitializeParams {
+            protocol_version: PROTOCOL_VERSION,
+            capabilities: types::ClientCapabilities {
+                experimental: None,
+                sampling: None,
+            },
+            client_info,
+        };
+
+        let response: types::InitializeResponse = self
+            .inner
+            .request(types::RequestType::Initialize.as_str(), params)
+            .await?;
+
+        log::trace!("mcp server info {:?}", response.server_info);
+
+        // Acknowledge initialization per the MCP lifecycle.
+        self.inner.notify(
+            types::NotificationType::Initialized.as_str(),
+            serde_json::json!({}),
+        )?;
+
+        let initialized_protocol = InitializedContextServerProtocol {
+            inner: self.inner,
+            initialize: response,
+        };
+
+        Ok(initialized_protocol)
+    }
+}
+
+/// Protocol handle available after a successful `initialize` handshake;
+/// `initialize` holds the server's advertised info and capabilities.
+pub struct InitializedContextServerProtocol {
+    inner: Client,
+    pub initialize: types::InitializeResponse,
+}
+
+/// Optional server capability groups advertised during initialization.
+#[derive(Debug, PartialEq, Clone, Copy)]
+pub enum ServerCapability {
+    Experimental,
+    Logging,
+    Prompts,
+    Resources,
+    Tools,
+}
+
+impl InitializedContextServerProtocol {
+    /// Check if the server supports a specific capability
+    pub fn capable(&self, capability: ServerCapability) -> bool {
+        match capability {
+            ServerCapability::Experimental => self.initialize.capabilities.experimental.is_some(),
+            ServerCapability::Logging => self.initialize.capabilities.logging.is_some(),
+            ServerCapability::Prompts => self.initialize.capabilities.prompts.is_some(),
+            ServerCapability::Resources => self.initialize.capabilities.resources.is_some(),
+            ServerCapability::Tools => self.initialize.capabilities.tools.is_some(),
+        }
+    }
+
+    // Returns an error if the server did not advertise the capability.
+    fn check_capability(&self, capability: ServerCapability) -> Result<()> {
+        if self.capable(capability) {
+            Ok(())
+        } else {
+            Err(anyhow::anyhow!(
+                "Server does not support {:?} capability",
+                capability
+            ))
+        }
+    }
+
+    /// List the MCP prompts.
+    pub async fn list_prompts(&self) -> Result<Vec<types::PromptInfo>> {
+        self.check_capability(ServerCapability::Prompts)?;
+
+        let response: types::PromptsListResponse = self
+            .inner
+            .request(types::RequestType::PromptsList.as_str(), ())
+            .await?;
+
+        Ok(response.prompts)
+    }
+
+    /// Executes a prompt with the given arguments and returns the result.
+    pub async fn run_prompt<P: AsRef<str>>(
+        &self,
+        prompt: P,
+        arguments: HashMap<String, String>,
+    ) -> Result<String> {
+        self.check_capability(ServerCapability::Prompts)?;
+
+        let params = types::PromptsGetParams {
+            name: prompt.as_ref().to_string(),
+            arguments: Some(arguments),
+        };
+
+        let response: types::PromptsGetResponse = self
+            .inner
+            .request(types::RequestType::PromptsGet.as_str(), params)
+            .await?;
+
+        Ok(response.prompt)
+    }
+}
+
+impl InitializedContextServerProtocol {
+    /// Escape hatch: send a raw request for methods not wrapped above.
+    pub async fn request<R: serde::de::DeserializeOwned>(
+        &self,
+        method: &str,
+        params: impl serde::Serialize,
+    ) -> Result<R> {
+        self.inner.request(method, params).await
+    }
+}
@@ -0,0 +1,59 @@
+use std::sync::Arc;
+
+use collections::HashMap;
+use gpui::{AppContext, Global, ReadGlobal};
+use parking_lot::RwLock;
+
+/// GPUI global wrapper holding the process-wide [`ContextServerRegistry`].
+struct GlobalContextServerRegistry(Arc<ContextServerRegistry>);
+
+impl Global for GlobalContextServerRegistry {}
+
+/// Tracks which commands each context server has registered.
+///
+/// Keys are server ids; values are the command names registered by that server.
+pub struct ContextServerRegistry {
+    registry: RwLock<HashMap<String, Vec<Arc<str>>>>,
+}
+
+impl ContextServerRegistry {
+    /// Returns the shared registry previously installed by [`Self::register`].
+    pub fn global(cx: &AppContext) -> Arc<Self> {
+        GlobalContextServerRegistry::global(cx).0.clone()
+    }
+
+    /// Installs an empty registry as a GPUI global; call once before `global`.
+    pub fn register(cx: &mut AppContext) {
+        cx.set_global(GlobalContextServerRegistry(Arc::new(
+            ContextServerRegistry {
+                registry: RwLock::new(HashMap::default()),
+            },
+        )))
+    }
+
+    /// Records that `server_id` provides `command_name`.
+    ///
+    /// NOTE(review): duplicate registrations are not deduplicated — confirm intended.
+    pub fn register_command(&self, server_id: String, command_name: &str) {
+        let mut registry = self.registry.write();
+        registry
+            .entry(server_id)
+            .or_default()
+            .push(command_name.into());
+    }
+
+    /// Removes every registration of `command_name` for `server_id`.
+    pub fn unregister_command(&self, server_id: &str, command_name: &str) {
+        let mut registry = self.registry.write();
+        if let Some(commands) = registry.get_mut(server_id) {
+            commands.retain(|name| name.as_ref() != command_name);
+        }
+    }
+
+    /// Returns a snapshot (clone) of the commands registered by `server_id`, if any.
+    pub fn get_commands(&self, server_id: &str) -> Option<Vec<Arc<str>>> {
+        let registry = self.registry.read();
+        registry.get(server_id).cloned()
+    }
+}
@@ -0,0 +1,234 @@
+use collections::HashMap;
+use serde::{Deserialize, Serialize};
+use url::Url;
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum RequestType { // Known MCP request methods; wire names come from `as_str`.
+    Initialize,
+    CallTool,
+    ResourcesUnsubscribe,
+    ResourcesSubscribe,
+    ResourcesRead,
+    ResourcesList,
+    LoggingSetLevel,
+    PromptsGet,
+    PromptsList,
+}
+
+impl RequestType {
+    pub fn as_str(&self) -> &'static str { // Maps each variant to its JSON-RPC method string.
+        match self {
+            RequestType::Initialize => "initialize",
+            RequestType::CallTool => "tools/call",
+            RequestType::ResourcesUnsubscribe => "resources/unsubscribe",
+            RequestType::ResourcesSubscribe => "resources/subscribe",
+            RequestType::ResourcesRead => "resources/read",
+            RequestType::ResourcesList => "resources/list",
+            RequestType::LoggingSetLevel => "logging/setLevel",
+            RequestType::PromptsGet => "prompts/get",
+            RequestType::PromptsList => "prompts/list",
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InitializeParams { // Params for the `initialize` handshake request.
+    pub protocol_version: u32,
+    pub capabilities: ClientCapabilities,
+    pub client_info: EntityInfo,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CallToolParams { // Params for `tools/call`.
+    pub name: String,
+    pub arguments: Option<serde_json::Value>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesUnsubscribeParams { // Params for `resources/unsubscribe`.
+    pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesSubscribeParams { // Params for `resources/subscribe`.
+    pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesReadParams { // Params for `resources/read`.
+    pub uri: Url,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LoggingSetLevelParams { // Params for `logging/setLevel`.
+    pub level: LoggingLevel,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsGetParams { // Params for `prompts/get`.
+    pub name: String,
+    pub arguments: Option<HashMap<String, String>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InitializeResponse { // Server's reply to `initialize`; drives capability checks.
+    pub protocol_version: u32,
+    pub capabilities: ServerCapabilities,
+    pub server_info: EntityInfo,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesReadResponse { // Reply to `resources/read`.
+    pub contents: Vec<ResourceContent>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesListResponse { // Reply to `resources/list`.
+    pub resource_templates: Option<Vec<ResourceTemplate>>,
+    pub resources: Vec<Resource>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsGetResponse { // Reply to `prompts/get`: the rendered prompt text.
+    pub prompt: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptsListResponse { // Reply to `prompts/list`.
+    pub prompts: Vec<PromptInfo>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptInfo { // Metadata describing one prompt offered by the server.
+    pub name: String,
+    pub arguments: Option<Vec<PromptArgument>>,
+}
+
+#[derive(Debug, Deserialize, Clone)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptArgument { // One named argument accepted by a prompt.
+    pub name: String,
+    pub description: Option<String>,
+    pub required: Option<bool>,
+}
+
+// Shared Types
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ClientCapabilities { // Capabilities this client advertises during `initialize`.
+    pub experimental: Option<HashMap<String, serde_json::Value>>,
+    pub sampling: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerCapabilities { // Capabilities the server advertises; `Some` means supported.
+    pub experimental: Option<HashMap<String, serde_json::Value>>,
+    pub logging: Option<HashMap<String, serde_json::Value>>,
+    pub prompts: Option<HashMap<String, serde_json::Value>>,
+    pub resources: Option<ResourcesCapabilities>,
+    pub tools: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourcesCapabilities { // Sub-capabilities of the `resources` capability.
+    pub subscribe: Option<bool>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Tool { // A callable tool; `input_schema` is a JSON Schema value.
+    pub name: String,
+    pub description: Option<String>,
+    pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct EntityInfo { // Name/version pair used for both client and server identification.
+    pub name: String,
+    pub version: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Resource { // A resource listed by the server, addressed by URI.
+    pub uri: Url,
+    pub mime_type: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceContent { // Contents of a read resource.
+    pub uri: Url,
+    pub mime_type: Option<String>,
+    pub content_type: String,
+    pub text: Option<String>, // Present for textual content.
+    pub data: Option<String>, // Present for non-text content — NOTE(review): encoding not shown here; confirm (likely base64).
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceTemplate { // A URI template describing a family of resources.
+    pub uri_template: String,
+    pub name: Option<String>,
+    pub description: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum LoggingLevel { // Serialized lowercase, e.g. "debug".
+    Debug,
+    Info,
+    Warning,
+    Error,
+}
+
+// Client Notifications
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum NotificationType { // Known notification methods; wire names come from `as_str`.
+    Initialized,
+    Progress,
+}
+
+impl NotificationType {
+    pub fn as_str(&self) -> &'static str { // Maps each variant to its JSON-RPC notification method string.
+        match self {
+            NotificationType::Initialized => "notifications/initialized",
+            NotificationType::Progress => "notifications/progress",
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum ClientNotification { // `untagged`: serializes the params payload alone; the unit variant serializes as JSON `null`.
+    Initialized,
+    Progress(ProgressParams),
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ProgressParams { // Payload for `notifications/progress`.
+    pub progress_token: String,
+    pub progress: f64,
+    pub total: Option<f64>, // `None` when total work is unknown.
+}
@@ -4379,11 +4379,11 @@ impl Editor {
this.refresh_inline_completion(true, cx);
});
- if let Some(confirm) = completion.confirm.as_ref() {
- (confirm)(intent, cx);
- }
-
- if completion.show_new_completions_on_confirm {
+ let show_new_completions_on_confirm = completion
+ .confirm
+ .as_ref()
+ .map_or(false, |confirm| confirm(intent, cx));
+ if show_new_completions_on_confirm {
self.show_completions(&ShowCompletions { trigger: None }, cx);
}
@@ -11926,6 +11926,12 @@ impl Editor {
let bounds = self.last_bounds?;
Some(element::gutter_bounds(bounds, self.gutter_dimensions))
}
+
+    pub fn has_active_completions_menu(&self) -> bool { // True when the open context menu is a visible completions menu.
+        self.context_menu.read().as_ref().map_or(false, |menu| {
+            menu.visible() && matches!(menu, ContextMenu::Completions(_))
+        })
+    }
}
fn hunks_for_selections(
@@ -12141,7 +12147,6 @@ fn snippet_completions(
..Default::default()
},
confirm: None,
- show_new_completions_on_confirm: false,
})
})
.collect()
@@ -777,19 +777,18 @@ fn editor_with_deleted_text(
});
let mut editor = Editor::for_multibuffer(multi_buffer, None, true, cx);
- editor.soft_wrap_mode_override = Some(language::language_settings::SoftWrap::None);
- editor.show_wrap_guides = Some(false);
- editor.show_gutter = false;
+ editor.set_soft_wrap_mode(language::language_settings::SoftWrap::None, cx);
+ editor.set_show_wrap_guides(false, cx);
+ editor.set_show_gutter(false, cx);
editor.scroll_manager.set_forbid_vertical_scroll(true);
editor.set_read_only(true);
-
- let editor_snapshot = editor.snapshot(cx);
- let start = editor_snapshot.buffer_snapshot.anchor_before(0);
- let end = editor_snapshot
- .buffer_snapshot
- .anchor_after(editor.buffer.read(cx).len(cx));
-
- editor.highlight_rows::<DiffRowHighlight>(start..=end, Some(deleted_color), false, cx);
+ editor.set_show_inline_completions(false);
+ editor.highlight_rows::<DiffRowHighlight>(
+ Anchor::min()..=Anchor::max(),
+ Some(deleted_color),
+ false,
+ cx,
+ );
let subscription_editor = parent_editor.clone();
editor._subscriptions.extend([
@@ -137,7 +137,6 @@ impl LanguageServerManifestEntry {
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct SlashCommandManifestEntry {
pub description: String,
- pub tooltip_text: String,
pub requires_argument: bool,
}
@@ -67,7 +67,7 @@ impl SlashCommand for ExtensionSlashCommand {
label: completion.label.into(),
new_text: completion.new_text,
replace_previous_arguments: false,
- run_command: completion.run_command,
+ after_completion: completion.run_command.into(),
})
.collect(),
)
@@ -1216,7 +1216,10 @@ impl ExtensionStore {
command: crate::wit::SlashCommand {
name: slash_command_name.to_string(),
description: slash_command.description.to_string(),
- tooltip_text: slash_command.tooltip_text.to_string(),
+ // We don't currently expose this as a configurable option, as it currently drives
+ // the `menu_text` on the `SlashCommand` trait, which is not used for slash commands
+ // defined in extensions, as they are not able to be added to the menu.
+ tooltip_text: String::new(),
requires_argument: slash_command.requires_argument,
},
extension: wasm_extension.clone(),
@@ -8,9 +8,6 @@ keywords = ["zed", "extension"]
edition = "2021"
license = "Apache-2.0"
-# We'll publish v0.1.0 after the release on Wednesday (2024-08-14).
-publish = false
-
[lints]
workspace = true
@@ -0,0 +1,12 @@
+# Pending Changes
+
+This is a list of pending changes to the Zed extension API that require a breaking change.
+
+This list should be updated as we notice things that should be changed so that we can batch them up in a single release.
+
+## vNext
+
+### Slash Commands
+
+- Rename `SlashCommand.tooltip_text` to `SlashCommand.menu_text`
+ - We may even want to remove it entirely, as right now this is only used for featured slash commands, and slash commands defined by extensions aren't currently able to be featured.
@@ -351,10 +351,13 @@ impl Asset for ImageAsset {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
+ let mut body = String::from_utf8_lossy(&body).into_owned();
+ let first_line = body.lines().next().unwrap_or("").trim_end();
+ body.truncate(first_line.len());
return Err(ImageCacheError::BadStatus {
uri,
status: response.status(),
- body: String::from_utf8_lossy(&body).into_owned(),
+ body,
});
}
body
@@ -71,7 +71,7 @@ pub use language_registry::{
PendingLanguageServer, QUERY_FILENAME_PREFIXES,
};
pub use lsp::LanguageServerId;
-pub use outline::{render_item, Outline, OutlineItem};
+pub use outline::*;
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer};
pub use text::{AnchorRangeExt, LineEnding};
pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
@@ -25,6 +25,9 @@ pub struct OutlineItem<T> {
pub annotation_range: Option<Range<T>>,
}
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct SymbolPath(pub String); // Newtype: the outline path-candidate string naming a matched symbol.
+
impl<T: ToPoint> OutlineItem<T> {
/// Converts to an equivalent outline item, but with parameterized over Points.
pub fn to_point(&self, buffer: &BufferSnapshot) -> OutlineItem<Point> {
@@ -85,7 +88,7 @@ impl<T> Outline<T> {
}
/// Find the most similar symbol to the provided query using normalized Levenshtein distance.
- pub fn find_most_similar(&self, query: &str) -> Option<&OutlineItem<T>> {
+ pub fn find_most_similar(&self, query: &str) -> Option<(SymbolPath, &OutlineItem<T>)> {
const SIMILARITY_THRESHOLD: f64 = 0.6;
let (position, similarity) = self
@@ -99,8 +102,10 @@ impl<T> Outline<T> {
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())?;
if similarity >= SIMILARITY_THRESHOLD {
- let item = self.items.get(position)?;
- Some(item)
+ self.path_candidates
+ .get(position)
+ .map(|candidate| SymbolPath(candidate.string.clone()))
+ .zip(self.items.get(position))
} else {
None
}
@@ -250,15 +255,15 @@ mod tests {
]);
assert_eq!(
outline.find_most_similar("pub fn process"),
- Some(&outline.items[0])
+ Some((SymbolPath("fn process".into()), &outline.items[0]))
);
assert_eq!(
outline.find_most_similar("async fn process"),
- Some(&outline.items[0])
+ Some((SymbolPath("fn process".into()), &outline.items[0])),
);
assert_eq!(
outline.find_most_similar("struct Processor"),
- Some(&outline.items[1])
+ Some((SymbolPath("struct DataProcessor".into()), &outline.items[1]))
);
assert_eq!(outline.find_most_similar("struct User"), None);
assert_eq!(outline.find_most_similar("struct"), None);
@@ -8,7 +8,7 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
-use futures::{future::BoxFuture, stream::BoxStream};
+use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
@@ -20,7 +20,7 @@ pub use registry::*;
pub use request::*;
pub use role::*;
use schemars::JsonSchema;
-use serde::de::DeserializeOwned;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{future::Future, sync::Arc};
use ui::IconName;
@@ -43,6 +43,14 @@ pub enum LanguageModelAvailability {
RequiresPlan(Plan),
}
+/// Configuration for caching language model messages.
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct LanguageModelCacheConfiguration { // Mirrors `anthropic::AnthropicModelCacheConfiguration` field-for-field.
+    pub max_cache_anchors: usize,
+    pub should_speculate: bool,
+    pub min_total_token: usize, // NOTE(review): presumably a minimum token threshold before caching applies — confirm.
+}
+
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
@@ -56,6 +64,9 @@ pub trait LanguageModel: Send + Sync {
}
fn max_token_count(&self) -> usize;
+ fn max_output_tokens(&self) -> Option<u32> {
+ None
+ }
fn count_tokens(
&self,
@@ -76,7 +87,11 @@ pub trait LanguageModel: Send + Sync {
description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>>;
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+ fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+ None
+ }
#[cfg(any(test, feature = "test-support"))]
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
@@ -92,12 +107,23 @@ impl dyn LanguageModel {
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
- let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
+ let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
- let response = request.await?;
- Ok(serde_json::from_value(response)?)
+ let stream = stream.await?;
+ let response = stream.try_collect::<String>().await?;
+ Ok(serde_json::from_str(&response)?)
}
}
+
+ pub fn use_tool_stream<T: LanguageModelTool>(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let schema = schemars::schema_for!(T);
+ let schema_json = serde_json::to_value(&schema).unwrap();
+ self.use_any_tool(request, T::name(), T::description(), schema_json, cx)
+ }
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
@@ -1,13 +1,13 @@
use crate::{
- settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
+ LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
-use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
View, WhiteSpace,
@@ -38,6 +38,8 @@ pub struct AvailableModel {
pub name: String,
pub max_tokens: usize,
pub tool_override: Option<String>,
+ pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+ pub max_output_tokens: Option<u32>,
}
pub struct AnthropicLanguageModelProvider {
@@ -171,6 +173,14 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
+ cache_configuration: model.cache_configuration.as_ref().map(|config| {
+ anthropic::AnthropicModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ }
+ }),
+ max_output_tokens: model.max_output_tokens,
},
);
}
@@ -264,29 +274,6 @@ pub fn count_anthropic_tokens(
}
impl AnthropicModel {
- fn request_completion(
- &self,
- request: anthropic::Request,
- cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<anthropic::Response>> {
- let http_client = self.http_client.clone();
-
- let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
- (state.api_key.clone(), settings.api_url.clone())
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
- anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
- .await
- .context("failed to retrieve completion")
- }
- .boxed()
- }
-
fn stream_completion(
&self,
request: anthropic::Request,
@@ -346,6 +333,10 @@ impl LanguageModel for AnthropicModel {
self.model.max_token_count()
}
+ fn max_output_tokens(&self) -> Option<u32> {
+ Some(self.model.max_output_tokens())
+ }
+
fn count_tokens(
&self,
request: LanguageModelRequest,
@@ -359,7 +350,8 @@ impl LanguageModel for AnthropicModel {
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = request.into_anthropic(self.model.id().into());
+ let request =
+ request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request.await.map_err(|err| anyhow!(err))?;
@@ -374,6 +366,16 @@ impl LanguageModel for AnthropicModel {
.boxed()
}
+ fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+ self.model
+ .cache_configuration()
+ .map(|config| LanguageModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ })
+ }
+
fn use_any_tool(
&self,
request: LanguageModelRequest,
@@ -381,8 +383,11 @@ impl LanguageModel for AnthropicModel {
tool_description: String,
input_schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
- let mut request = request.into_anthropic(self.model.tool_model_id().into());
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let mut request = request.into_anthropic(
+ self.model.tool_model_id().into(),
+ self.model.max_output_tokens(),
+ );
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@@ -392,25 +397,16 @@ impl LanguageModel for AnthropicModel {
input_schema,
}];
- let response = self.request_completion(request, cx);
+ let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
- response
- .content
- .into_iter()
- .find_map(|content| {
- if let anthropic::Content::ToolUse { name, input, .. } = content {
- if name == tool_name {
- Some(input)
- } else {
- None
- }
- } else {
- None
- }
- })
- .context("tool not used")
+ Ok(anthropic::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response.map_err(|e| anyhow!(e))),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
@@ -1,22 +1,25 @@
use super::open_ai::count_open_ai_tokens;
use crate::{
- settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
- LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+ settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+ LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, ZedPro};
-use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
+use futures::{
+ future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
+ TryStreamExt as _,
+};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
@@ -53,6 +56,8 @@ pub struct AvailableModel {
name: String,
max_tokens: usize,
tool_override: Option<String>,
+ cache_configuration: Option<LanguageModelCacheConfiguration>,
+ max_output_tokens: Option<u32>,
}
pub struct CloudLanguageModelProvider {
@@ -199,6 +204,14 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
name: model.name.clone(),
max_tokens: model.max_tokens,
tool_override: model.tool_override.clone(),
+ cache_configuration: model.cache_configuration.as_ref().map(|config| {
+ anthropic::AnthropicModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ }
+ }),
+ max_output_tokens: model.max_output_tokens,
})
}
AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
@@ -435,7 +448,7 @@ impl LanguageModel for CloudLanguageModel {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
- let request = request.into_anthropic(model.id().into());
+ let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
@@ -451,21 +464,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: anthropic::Event = serde_json::from_str(&buffer)
- .context("failed to parse Anthropic event")?;
- Ok(Some((event, body)))
- }
- Err(err) => Err(AnthropicError::Other(err.into())),
- }
- });
-
- Ok(anthropic::extract_text_from_events(stream))
+ Ok(anthropic::extract_text_from_events(
+ response_lines(response).map_err(AnthropicError::Other),
+ ))
});
async move {
Ok(future
@@ -492,21 +493,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
+ Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -527,21 +514,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: google_ai::GenerateContentResponse =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(google_ai::extract_text_from_events(stream))
+ Ok(google_ai::extract_text_from_events(response_lines(
+ response,
+ )))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -563,21 +538,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
- let body = BufReader::new(response.into_body());
- let stream = futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
+ Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@@ -591,11 +552,14 @@ impl LanguageModel for CloudLanguageModel {
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let client = self.client.clone();
+ let llm_api_token = self.llm_api_token.clone();
+
match &self.model {
CloudModel::Anthropic(model) => {
- let client = self.client.clone();
- let mut request = request.into_anthropic(model.tool_model_id().into());
+ let mut request =
+ request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
});
@@ -605,7 +569,6 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -621,70 +584,34 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut tool_use_index = None;
- let mut tool_input = String::new();
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- while body.read_line(&mut line).await? > 0 {
- let event: anthropic::Event = serde_json::from_str(&line)?;
- line.clear();
-
- match event {
- anthropic::Event::ContentBlockStart {
- content_block,
- index,
- } => {
- if let anthropic::Content::ToolUse { name, .. } = content_block
- {
- if name == tool_name {
- tool_use_index = Some(index);
- }
- }
- }
- anthropic::Event::ContentBlockDelta { index, delta } => match delta
- {
- anthropic::ContentDelta::TextDelta { .. } => {}
- anthropic::ContentDelta::InputJsonDelta { partial_json } => {
- if Some(index) == tool_use_index {
- tool_input.push_str(&partial_json);
- }
- }
- },
- anthropic::Event::ContentBlockStop { index } => {
- if Some(index) == tool_use_index {
- return Ok(serde_json::from_str(&tool_input)?);
- }
- }
- _ => {}
- }
- }
-
- if tool_use_index.is_some() {
- Err(anyhow!("tool content incomplete"))
- } else {
- Err(anyhow!("tool not used"))
- }
+ Ok(anthropic::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
- let client = self.client.clone();
- let mut function = open_ai::FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = open_ai::ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(input_schema);
- request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ request.tool_choice = Some(open_ai::ToolChoice::Other(
+ open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ },
+ ));
+ request.tools = vec![open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(input_schema),
+ },
+ }];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -700,41 +627,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
-
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ Ok(open_ai::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
@@ -744,22 +642,23 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Zed(model) => {
// All Zed models are OpenAI-based at the time of writing.
let mut request = request.into_open_ai(model.id().into());
- let client = self.client.clone();
- let mut function = open_ai::FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = open_ai::ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(input_schema);
- request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ request.tool_choice = Some(open_ai::ToolChoice::Other(
+ open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ },
+ ));
+ request.tools = vec![open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(input_schema),
+ },
+ }];
- let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@@ -775,40 +674,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ Ok(open_ai::extract_tool_args_from_events(
+ tool_name,
+ Box::pin(response_lines(response)),
+ )
+ .await?
+ .boxed())
})
.boxed()
}
@@ -816,6 +687,25 @@ impl LanguageModel for CloudLanguageModel {
}
}
+fn response_lines<T: DeserializeOwned>(
+ response: Response<AsyncBody>,
+) -> impl Stream<Item = Result<T>> {
+ futures::stream::try_unfold(
+ (String::new(), BufReader::new(response.into_body())),
+ move |(mut line, mut body)| async {
+ match body.read_line(&mut line).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: T = serde_json::from_str(&line)?;
+ line.clear();
+ Ok(Some((event, (line, body))))
+ }
+ Err(e) => Err(e.into()),
+ }
+ },
+ )
+}
+
impl LlmApiToken {
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;
@@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
@@ -3,16 +3,11 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
-use anyhow::Context as _;
-use futures::{
- channel::{mpsc, oneshot},
- future::BoxFuture,
- stream::BoxStream,
- FutureExt, StreamExt,
-};
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
use parking_lot::Mutex;
+use serde::Serialize;
use std::sync::Arc;
use ui::WindowContext;
@@ -90,7 +85,7 @@ pub struct ToolUseRequest {
#[derive(Default)]
pub struct FakeLanguageModel {
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
- current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
+ current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
}
impl FakeLanguageModel {
@@ -130,25 +125,11 @@ impl FakeLanguageModel {
self.end_completion_stream(self.pending_completions().last().unwrap());
}
- pub fn respond_to_tool_use(
- &self,
- tool_call: &ToolUseRequest,
- response: Result<serde_json::Value>,
- ) {
- let mut current_tool_call_txs = self.current_tool_use_txs.lock();
- if let Some(index) = current_tool_call_txs
- .iter()
- .position(|(call, _)| call == tool_call)
- {
- let (_, tx) = current_tool_call_txs.remove(index);
- tx.send(response).unwrap();
- }
- }
-
- pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
+ pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
+ let response = serde_json::to_string(&response).unwrap();
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
let (_, tx) = current_tool_call_txs.pop().unwrap();
- tx.send(response).unwrap();
+ tx.unbounded_send(response).unwrap();
}
}
@@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
description: String,
schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
- let (tx, rx) = oneshot::channel();
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let (tx, rx) = mpsc::unbounded();
let tool_call = ToolUseRequest {
request,
name,
@@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
schema,
};
self.current_tool_use_txs.lock().push((tool_call, tx));
- async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
+ async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {
@@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}
@@ -6,7 +6,6 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
-use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
@@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
@@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
self.request_limiter
.run(async move {
let response = response.await?;
- let ChatMessage::Assistant {
- tool_calls,
- content,
- } = response.message
- else {
+ let ChatMessage::Assistant { tool_calls, .. } = response.message else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
- return Ok(function.arguments);
+ return Ok(futures::stream::once(async move {
+ Ok(function.arguments.to_string())
+ })
+ .boxed());
}
}
- } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
- // Parse content as arguments.
- return Ok(args);
} else {
bail!("assistant message does not have any tool calls");
};
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, bail, Result};
+use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
async move { Ok(future.await?.boxed()) }.boxed()
}
}
+
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
- ) -> BoxFuture<'static, Result<serde_json::Value>> {
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_open_ai(self.model.id().into());
- let mut function = FunctionDefinition {
- name: tool_name.clone(),
- description: None,
- parameters: None,
- };
- let func = ToolDefinition::Function {
- function: function.clone(),
- };
- request.tool_choice = Some(ToolChoice::Other(func.clone()));
- // Fill in description and params separately, as they're not needed for tool_choice field.
- function.description = Some(tool_description);
- function.parameters = Some(schema);
- request.tools = vec![ToolDefinition::Function { function }];
+ request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ },
+ }));
+ request.tools = vec![ToolDefinition::Function {
+ function: FunctionDefinition {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(schema),
+ },
+ }];
+
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
- let mut response = response.await?;
-
- // Call arguments are gonna be streamed in over multiple chunks.
- let mut load_state = None;
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
+ let response = response.await?;
+ Ok(
+ open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
+ .await?
+ .boxed(),
+ )
})
.boxed()
}
@@ -193,6 +193,7 @@ impl From<&str> for MessageContent {
pub struct LanguageModelRequestMessage {
pub role: Role,
pub content: Vec<MessageContent>,
+ pub cache: bool,
}
impl LanguageModelRequestMessage {
@@ -213,7 +214,7 @@ impl LanguageModelRequestMessage {
.content
.get(0)
.map(|content| match content {
- MessageContent::Text(s) => s.is_empty(),
+ MessageContent::Text(s) => s.trim().is_empty(),
MessageContent::Image(_) => true,
})
.unwrap_or(false)
@@ -285,8 +286,8 @@ impl LanguageModelRequest {
}
}
- pub fn into_anthropic(self, model: String) -> anthropic::Request {
- let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+ pub fn into_anthropic(self, model: String, max_output_tokens: u32) -> anthropic::Request {
+ let mut new_messages: Vec<anthropic::Message> = Vec::new();
let mut system_message = String::new();
for message in self.messages {
@@ -296,18 +297,49 @@ impl LanguageModelRequest {
match message.role {
Role::User | Role::Assistant => {
+ let cache_control = if message.cache {
+ Some(anthropic::CacheControl {
+ cache_type: anthropic::CacheControlType::Ephemeral,
+ })
+ } else {
+ None
+ };
+ let anthropic_message_content: Vec<anthropic::Content> = message
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ MessageContent::Text(t) if !t.is_empty() => {
+ Some(anthropic::Content::Text {
+ text: t,
+ cache_control,
+ })
+ }
+ MessageContent::Image(i) => Some(anthropic::Content::Image {
+ source: anthropic::ImageSource {
+ source_type: "base64".to_string(),
+ media_type: "image/png".to_string(),
+ data: i.source.to_string(),
+ },
+ cache_control,
+ }),
+ _ => None,
+ })
+ .collect();
+ let anthropic_role = match message.role {
+ Role::User => anthropic::Role::User,
+ Role::Assistant => anthropic::Role::Assistant,
+ Role::System => unreachable!("System role should never occur here"),
+ };
if let Some(last_message) = new_messages.last_mut() {
- if last_message.role == message.role {
- // TODO: is this append done properly?
- last_message.content.push(MessageContent::Text(format!(
- "\n\n{}",
- message.string_contents()
- )));
+ if last_message.role == anthropic_role {
+ last_message.content.extend(anthropic_message_content);
continue;
}
}
-
- new_messages.push(message);
+ new_messages.push(anthropic::Message {
+ role: anthropic_role,
+ content: anthropic_message_content,
+ });
}
Role::System => {
if !system_message.is_empty() {
@@ -320,37 +352,8 @@ impl LanguageModelRequest {
anthropic::Request {
model,
- messages: new_messages
- .into_iter()
- .filter_map(|message| {
- Some(anthropic::Message {
- role: match message.role {
- Role::User => anthropic::Role::User,
- Role::Assistant => anthropic::Role::Assistant,
- Role::System => return None,
- },
- content: message
- .content
- .into_iter()
- // TODO: filter out the empty messages in the message construction step
- .filter_map(|content| match content {
- MessageContent::Text(t) if !t.is_empty() => {
- Some(anthropic::Content::Text { text: t })
- }
- MessageContent::Image(i) => Some(anthropic::Content::Image {
- source: anthropic::ImageSource {
- source_type: "base64".to_string(),
- media_type: "image/png".to_string(),
- data: i.source.to_string(),
- },
- }),
- _ => None,
- })
- .collect(),
- })
- })
- .collect(),
- max_tokens: 4092,
+ messages: new_messages,
+ max_tokens: max_output_tokens,
system: Some(system_message),
tools: Vec::new(),
tool_choice: None,
@@ -7,14 +7,17 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{update_settings_file, Settings, SettingsSources};
-use crate::provider::{
- self,
- anthropic::AnthropicSettings,
- cloud::{self, ZedDotDevSettings},
- copilot_chat::CopilotChatSettings,
- google::GoogleSettings,
- ollama::OllamaSettings,
- open_ai::OpenAiSettings,
+use crate::{
+ provider::{
+ self,
+ anthropic::AnthropicSettings,
+ cloud::{self, ZedDotDevSettings},
+ copilot_chat::CopilotChatSettings,
+ google::GoogleSettings,
+ ollama::OllamaSettings,
+ open_ai::OpenAiSettings,
+ },
+ LanguageModelCacheConfiguration,
};
/// Initializes the language model settings.
@@ -93,10 +96,20 @@ impl AnthropicSettingsContent {
name,
max_tokens,
tool_override,
+ cache_configuration,
+ max_output_tokens,
} => Some(provider::anthropic::AvailableModel {
name,
max_tokens,
tool_override,
+ cache_configuration: cache_configuration.as_ref().map(
+ |config| LanguageModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: config.should_speculate,
+ min_total_token: config.min_total_token,
+ },
+ ),
+ max_output_tokens,
}),
_ => None,
})
@@ -1,3 +1,5 @@
+#![allow(clippy::arc_with_non_send_sync)]
+
use std::sync::Arc;
#[cfg(all(target_os = "macos", not(any(test, feature = "test-support"))))]
@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -92,7 +92,7 @@ impl Model {
}
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
@@ -107,16 +107,16 @@ pub enum ChatMessage {
},
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
Function(OllamaFunctionCall),
}
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
pub name: String,
- pub arguments: Value,
+ pub arguments: Box<RawValue>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
-use std::{convert::TryFrom, future::Future, time::Duration};
+use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use strum::EnumIter;
pub use supported_countries::*;
@@ -384,6 +384,57 @@ pub fn embed<'a>(
}
}
+pub async fn extract_tool_args_from_events(
+ tool_name: String,
+ mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> Result<impl Send + Stream<Item = Result<String>>> {
+ let mut tool_use_index = None;
+ let mut first_chunk = None;
+ while let Some(event) = events.next().await {
+ let call = event?.choices.into_iter().find_map(|choice| {
+ choice.delta.tool_calls?.into_iter().find_map(|call| {
+ if call.function.as_ref()?.name.as_deref()? == tool_name {
+ Some(call)
+ } else {
+ None
+ }
+ })
+ });
+ if let Some(call) = call {
+ tool_use_index = Some(call.index);
+ first_chunk = call.function.and_then(|func| func.arguments);
+ break;
+ }
+ }
+
+ let Some(tool_use_index) = tool_use_index else {
+ return Err(anyhow!("tool not used"));
+ };
+
+ Ok(events.filter_map(move |event| {
+ let result = match event {
+ Err(error) => Some(Err(error)),
+ Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
+ choice.delta.tool_calls?.into_iter().find_map(|call| {
+ if call.index == tool_use_index {
+ let func = call.function?;
+ let mut arguments = func.arguments?;
+ if let Some(mut first_chunk) = first_chunk.take() {
+ first_chunk.push_str(&arguments);
+ arguments = first_chunk
+ }
+ Some(Ok(arguments))
+ } else {
+ None
+ }
+ })
+ }),
+ };
+
+ async move { result }
+ }))
+}
+
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {
@@ -12,6 +12,7 @@ pub mod worktree_store;
#[cfg(test)]
mod project_tests;
+
pub mod search_history;
mod yarn;
@@ -450,9 +451,10 @@ pub struct Completion {
/// The raw completion provided by the language server.
pub lsp_completion: lsp::CompletionItem,
/// An optional callback to invoke when this completion is confirmed.
- pub confirm: Option<Arc<dyn Send + Sync + Fn(CompletionIntent, &mut WindowContext)>>,
- /// If true, the editor will show a new completion menu after this completion is confirmed.
- pub show_new_completions_on_confirm: bool,
+ /// Returns whether new completions should be retriggered after the current one.
+ /// If `true` is returned, the editor will show a new completion menu after this completion is confirmed.
+ /// If no confirmation is provided or `false` is returned, the completion will be committed.
+ pub confirm: Option<Arc<dyn Send + Sync + Fn(CompletionIntent, &mut WindowContext) -> bool>>,
}
impl std::fmt::Debug for Completion {
@@ -9128,7 +9130,6 @@ impl Project {
filter_range: Default::default(),
},
confirm: None,
- show_new_completions_on_confirm: false,
},
false,
cx,
@@ -10765,7 +10766,6 @@ async fn populate_labels_for_completions(
documentation,
lsp_completion,
confirm: None,
- show_new_completions_on_confirm: false,
})
}
}
@@ -5204,7 +5204,7 @@ async fn search(
.collect())
}
-fn init_test(cx: &mut gpui::TestAppContext) {
+pub fn init_test(cx: &mut gpui::TestAppContext) {
if std::env::var("RUST_LOG").is_ok() {
env_logger::try_init().ok();
}
@@ -260,7 +260,7 @@ impl Project {
.and_then(|(worktree, relative_path)| {
worktree.read(cx).entry_for_path(&relative_path)
})
- .is_some()
+ .is_some_and(|entry| entry.is_dir())
})
}
@@ -250,24 +250,38 @@ pub struct ErrorView {
}
impl ErrorView {
- fn render(&self, cx: &ViewContext<ExecutionView>) -> Option<AnyElement> {
+ fn render(&self, cx: &mut ViewContext<ExecutionView>) -> Option<AnyElement> {
let theme = cx.theme();
let padding = cx.line_height() / 2.;
Some(
v_flex()
- .w_full()
- .px(padding)
- .py(padding)
- .border_1()
- .border_color(theme.status().error_border)
+ .gap_3()
.child(
h_flex()
- .font_weight(FontWeight::BOLD)
- .child(format!("{}: {}", self.ename, self.evalue)),
+ .font_buffer(cx)
+ .child(
+ Label::new(format!("{}: ", self.ename.clone()))
+ // .size(LabelSize::Large)
+ .color(Color::Error)
+ .weight(FontWeight::BOLD),
+ )
+ .child(
+ Label::new(self.evalue.clone())
+ // .size(LabelSize::Large)
+ .weight(FontWeight::BOLD),
+ ),
+ )
+ .child(
+ div()
+ .w_full()
+ .px(padding)
+ .py(padding)
+ .border_l_1()
+ .border_color(theme.status().error_border)
+ .child(self.traceback.render(cx)),
)
- .child(self.traceback.render(cx))
.into_any_element(),
)
}
@@ -358,7 +372,7 @@ pub enum OutputContent {
}
impl OutputContent {
- fn render(&self, cx: &ViewContext<ExecutionView>) -> Option<AnyElement> {
+ fn render(&self, cx: &mut ViewContext<ExecutionView>) -> Option<AnyElement> {
let el = match self {
// Note: in typical frontends we would show the execute_result.execution_count
// Here we can just handle either
@@ -1,9 +1,9 @@
use crate::components::KernelListItem;
-use crate::KernelStatus;
use crate::{
kernels::{Kernel, KernelSpecification, RunningKernel},
outputs::{ExecutionStatus, ExecutionView},
};
+use crate::{stdio, KernelStatus};
use client::telemetry::Telemetry;
use collections::{HashMap, HashSet};
use editor::{
@@ -26,9 +26,8 @@ use runtimelib::{
ExecuteRequest, ExecutionState, InterruptRequest, JupyterMessage, JupyterMessageContent,
ShutdownRequest,
};
-use settings::Settings as _;
use std::{env::temp_dir, ops::Range, sync::Arc, time::Duration};
-use theme::{ActiveTheme, ThemeSettings};
+use theme::ActiveTheme;
use ui::{prelude::*, IconButtonShape, Tooltip};
pub struct Session {
@@ -114,68 +113,68 @@ impl EditorBlock {
) -> RenderBlock {
let render = move |cx: &mut BlockContext| {
let execution_view = execution_view.clone();
- let text_font = ThemeSettings::get_global(cx).buffer_font.family.clone();
- let text_font_size = ThemeSettings::get_global(cx).buffer_font_size;
+ let text_style = stdio::text_style(cx);
let gutter = cx.gutter_dimensions;
- let close_button_size = IconSize::XSmall;
let block_id = cx.block_id;
let on_close = on_close.clone();
let rem_size = cx.rem_size();
- let line_height = cx.text_style().line_height_in_pixels(rem_size);
- let (close_button_width, close_button_padding) =
- close_button_size.square_components(cx);
+ let text_line_height = text_style.line_height_in_pixels(rem_size);
+
+ let close_button = h_flex()
+ .flex_none()
+ .items_center()
+ .justify_center()
+ .absolute()
+ .top(text_line_height / 2.)
+ .right(
+ // 2px is a magic number to nudge the button just a bit closer to
+ // the line number start
+ gutter.full_width() / 2.0 - text_line_height / 2.0 - px(2.),
+ )
+ .w(text_line_height)
+ .h(text_line_height)
+ .child(
+ IconButton::new(
+ ("close_output_area", EntityId::from(cx.block_id)),
+ IconName::Close,
+ )
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .size(ButtonSize::Compact)
+ .shape(IconButtonShape::Square)
+ .tooltip(|cx| Tooltip::text("Close output area", cx))
+ .on_click(move |_, cx| {
+ if let BlockId::Custom(block_id) = block_id {
+ (on_close)(block_id, cx)
+ }
+ }),
+ );
div()
- .min_h(line_height)
.flex()
- .flex_row()
.items_start()
+ .min_h(text_line_height)
.w_full()
- .bg(cx.theme().colors().background)
.border_y_1()
.border_color(cx.theme().colors().border)
+ .bg(cx.theme().colors().background)
.child(
- v_flex().min_h(cx.line_height()).justify_center().child(
- h_flex()
- .w(gutter.full_width())
- .justify_end()
- .pt(line_height / 2.)
- .child(
- h_flex()
- .pr(gutter.width / 2. - close_button_width
- + close_button_padding / 2.)
- .child(
- IconButton::new(
- ("close_output_area", EntityId::from(cx.block_id)),
- IconName::Close,
- )
- .shape(IconButtonShape::Square)
- .icon_size(close_button_size)
- .icon_color(Color::Muted)
- .tooltip(|cx| Tooltip::text("Close output area", cx))
- .on_click(
- move |_, cx| {
- if let BlockId::Custom(block_id) = block_id {
- (on_close)(block_id, cx)
- }
- },
- ),
- ),
- ),
- ),
+ div()
+ .relative()
+ .w(gutter.full_width())
+ .h(text_line_height * 2)
+ .child(close_button),
)
.child(
div()
.flex_1()
.size_full()
- .my_2()
+ .py(text_line_height / 2.)
.mr(gutter.width)
- .text_size(text_font_size)
- .font_family(text_font)
.child(execution_view),
)
.into_any_element()
@@ -1,9 +1,11 @@
use crate::outputs::ExecutionView;
use alacritty_terminal::{term::Config, vte::ansi::Processor};
-use gpui::{canvas, size, AnyElement};
+use gpui::{canvas, size, AnyElement, FontStyle, TextStyle, WhiteSpace};
+use settings::Settings as _;
use std::mem;
use terminal::ZedListener;
use terminal_view::terminal_element::TerminalElement;
+use theme::ThemeSettings;
use ui::{prelude::*, IntoElement, ViewContext};
/// Implements the most basic of terminal output for use by Jupyter outputs
@@ -22,8 +24,38 @@ pub struct TerminalOutput {
const DEFAULT_NUM_LINES: usize = 32;
const DEFAULT_NUM_COLUMNS: usize = 128;
+pub fn text_style(cx: &mut WindowContext) -> TextStyle {
+ let settings = ThemeSettings::get_global(cx).clone();
+
+ let font_family = settings.buffer_font.family;
+ let font_features = settings.buffer_font.features;
+ let font_weight = settings.buffer_font.weight;
+ let font_fallbacks = settings.buffer_font.fallbacks;
+
+ let theme = cx.theme();
+
+ let text_style = TextStyle {
+ font_family,
+ font_features,
+ font_weight,
+ font_fallbacks,
+ font_size: theme::get_buffer_font_size(cx).into(),
+ font_style: FontStyle::Normal,
+ // todo
+ line_height: cx.line_height().into(),
+ background_color: Some(theme.colors().terminal_background),
+ white_space: WhiteSpace::Normal,
+ // These are going to be overridden per-cell
+ underline: None,
+ strikethrough: None,
+ color: theme.colors().terminal_foreground,
+ };
+
+ text_style
+}
+
pub fn terminal_size(cx: &mut WindowContext) -> terminal::TerminalSize {
- let text_style = cx.text_style();
+ let text_style = text_style(cx);
let text_system = cx.text_system();
let line_height = cx.line_height();
@@ -86,8 +118,8 @@ impl TerminalOutput {
}
}
- pub fn render(&self, cx: &ViewContext<ExecutionView>) -> AnyElement {
- let text_style = cx.text_style();
+ pub fn render(&self, cx: &mut ViewContext<ExecutionView>) -> AnyElement {
+ let text_style = text_style(cx);
let text_system = cx.text_system();
let grid = self
@@ -101,10 +133,9 @@ impl TerminalOutput {
let (cells, rects) = TerminalElement::layout_grid(grid, &text_style, text_system, None, cx);
// lines are 0-indexed, so we must add 1 to get the number of lines
+ let text_line_height = text_style.line_height_in_pixels(cx.rem_size());
let num_lines = cells.iter().map(|c| c.point.line).max().unwrap_or(0) + 1;
- let height = num_lines as f32 * cx.line_height();
-
- let line_height = cx.line_height();
+ let height = num_lines as f32 * text_line_height;
let font_pixels = text_style.font_size.to_pixels(cx.rem_size());
let font_id = text_system.resolve_font(&text_style.font());
@@ -124,7 +155,7 @@ impl TerminalOutput {
bounds.origin,
&terminal::TerminalSize {
cell_width,
- line_height,
+ line_height: text_line_height,
size: bounds.size,
},
cx,
@@ -136,7 +167,7 @@ impl TerminalOutput {
bounds.origin,
&terminal::TerminalSize {
cell_width,
- line_height,
+ line_height: text_line_height,
size: bounds.size,
},
bounds,
@@ -263,7 +263,7 @@ impl PathMatcher {
let path_str = path.to_string_lossy();
let separator = std::path::MAIN_SEPARATOR_STR;
if path_str.ends_with(separator) {
- self.glob.is_match(path)
+ return false;
} else {
self.glob.is_match(path_str.to_string() + separator)
}
@@ -2,7 +2,7 @@
description = "The fast, collaborative code editor."
edition = "2021"
name = "zed"
-version = "0.149.0"
+version = "0.150.0"
publish = false
license = "GPL-3.0-or-later"
authors = ["Zed Team <hi@zed.dev>"]
@@ -38,6 +38,7 @@
- [Developing Extensions](./extensions/developing-extensions.md)
- [Language Extensions](./extensions/languages.md)
- [Theme Extensions](./extensions/themes.md)
+- [Slash Commands](./extensions/slash-commands.md)
# Language Support
@@ -6,3 +6,4 @@ Zed lets you add new functionality using user-defined extensions.
- [Developing Extensions](./extensions/developing-extensions.md)
- [Developing Language Extensions](./extensions/languages.md)
- [Developing Themes](./extensions/themes.md)
+ - [Developing Slash Commands](./extensions/slash-commands.md)
@@ -6,6 +6,7 @@ Extensions can add the following capabilities to Zed:
- [Languages](./languages.md)
- [Themes](./themes.md)
+- [Slash Commands](./slash-commands.md)
## Directory Structure of a Zed Extension
@@ -22,10 +23,6 @@ description = "My cool extension"
repository = "https://github.com/your-name/my-zed-extension"
```
-<!--
-TBD: Document `slash_commands`, `indexed_docs_providers` (see: extensions/gleam/extension.toml)
--->
-
In addition to this, there are several other optional files and directories that can be used to add functionality to a Zed extension. An example directory structure of an extension that provides all capabilities is as follows:
```
@@ -0,0 +1,138 @@
+# Slash Commands
+
+Extensions may provide slash commands for use in the Assistant.
+
+## Example extension
+
+To see a working example of an extension that provides slash commands, check out the [`slash-commands-example` extension](https://github.com/zed-industries/zed/tree/main/extensions/slash-commands-example).
+
+This extension can be [installed as a dev extension](./developing-extensions.html#developing-an-extension-locally) if you want to try it out for yourself.
+
+## Defining slash commands
+
+A given extension may provide one or more slash commands. Each slash command must be registered in the `extension.toml`.
+
+For example, here is an extension that provides two slash commands: `/echo` and `/pick-one`:
+
+```toml
+[slash_commands.echo]
+description = "echoes the provided input"
+requires_argument = true
+
+[slash_commands.pick-one]
+description = "pick one of three options"
+requires_argument = true
+```
+
+Each slash command may define the following properties:
+
+- `description`: A description of the slash command that will be shown when completing available commands.
+- `requires_argument`: Indicates whether a slash command requires at least one argument to run.
+
+## Implementing slash command behavior
+
+To implement behavior for your slash commands, implement `run_slash_command` for your extension.
+
+This method accepts the slash command that will be run, the list of arguments passed to it, and an optional `Worktree`.
+
+This method returns `SlashCommandOutput`, which contains the textual output of the command in the `text` field. The output may also define `SlashCommandOutputSection`s that contain ranges into the output. These sections are then rendered as creases in the Assistant's context editor.
+
+Your extension should `match` on the command name (without the leading `/`) and then execute behavior accordingly:
+
+```rs
+impl zed::Extension for MyExtension {
+ fn run_slash_command(
+ &self,
+ command: SlashCommand,
+ args: Vec<String>,
+ _worktree: Option<&Worktree>,
+ ) -> Result<SlashCommandOutput, String> {
+ match command.name.as_str() {
+ "echo" => {
+ if args.is_empty() {
+ return Err("nothing to echo".to_string());
+ }
+
+ let text = args.join(" ");
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: (0..text.len()).into(),
+ label: "Echo".to_string(),
+ }],
+ text,
+ })
+ }
+ "pick-one" => {
+ let Some(selection) = args.first() else {
+ return Err("no option selected".to_string());
+ };
+
+ match selection.as_str() {
+ "option-1" | "option-2" | "option-3" => {}
+ invalid_option => {
+ return Err(format!("{invalid_option} is not a valid option"));
+ }
+ }
+
+ let text = format!("You chose {selection}.");
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: (0..text.len()).into(),
+ label: format!("Pick One: {selection}"),
+ }],
+ text,
+ })
+ }
+ command => Err(format!("unknown slash command: \"{command}\"")),
+ }
+ }
+}
+```
+
+## Auto-completing slash command arguments
+
+For slash commands that have arguments, you may also choose to implement `complete_slash_command_argument` to provide completions for your slash commands.
+
+This method accepts the slash command that will be run and the list of arguments passed to it. It returns a list of `SlashCommandArgumentCompletion`s that will be shown in the completion menu.
+
+A `SlashCommandArgumentCompletion` consists of the following properties:
+
+- `label`: The label that will be shown in the completion menu.
+- `new_text`: The text that will be inserted when the completion is accepted.
+- `run_command`: Whether the slash command will be run when the completion is accepted.
+
+Once again, your extension should `match` on the command name (without the leading `/`) and return the desired argument completions:
+
+```rs
+impl zed::Extension for MyExtension {
+ fn complete_slash_command_argument(
+ &self,
+ command: SlashCommand,
+ _args: Vec<String>,
+ ) -> Result<Vec<SlashCommandArgumentCompletion>, String> {
+ match command.name.as_str() {
+ "echo" => Ok(vec![]),
+ "pick-one" => Ok(vec![
+ SlashCommandArgumentCompletion {
+ label: "Option One".to_string(),
+ new_text: "option-1".to_string(),
+ run_command: true,
+ },
+ SlashCommandArgumentCompletion {
+ label: "Option Two".to_string(),
+ new_text: "option-2".to_string(),
+ run_command: true,
+ },
+ SlashCommandArgumentCompletion {
+ label: "Option Three".to_string(),
+ new_text: "option-3".to_string(),
+ run_command: true,
+ },
+ ]),
+ command => Err(format!("unknown slash command: \"{command}\"")),
+ }
+ }
+}
+```
@@ -1,6 +1,6 @@
[package]
name = "zed_gleam"
-version = "0.1.3"
+version = "0.2.0"
edition = "2021"
publish = false
license = "Apache-2.0"
@@ -14,4 +14,4 @@ crate-type = ["cdylib"]
[dependencies]
html_to_markdown = "0.1.0"
-zed_extension_api = { path = "../../crates/extension_api" }
+zed_extension_api = "0.1.0"
@@ -1,7 +1,7 @@
id = "gleam"
name = "Gleam"
description = "Gleam support."
-version = "0.1.3"
+version = "0.2.0"
schema_version = 1
authors = ["Marshall Bowers <elliott.codes@gmail.com>"]
repository = "https://github.com/zed-industries/zed"
@@ -17,6 +17,5 @@ commit = "426e67087fd62be5f4533581b5916b2cf010fb5b"
[slash_commands.gleam-project]
description = "Returns information about the current Gleam project."
requires_argument = false
-tooltip_text = "Insert Gleam project data"
[indexed_docs_providers.gleam-hexdocs]
@@ -0,0 +1,16 @@
+[package]
+name = "slash_commands_example"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "Apache-2.0"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/slash_commands_example.rs"
+crate-type = ["cdylib"]
+
+[dependencies]
+zed_extension_api = "0.1.0"
@@ -0,0 +1 @@
+../../LICENSE-APACHE
@@ -0,0 +1,3 @@
+# Slash Commands Example Extension
+
+This is an example extension showcasing how to write slash commands.
@@ -0,0 +1,15 @@
+id = "slash-commands-example"
+name = "Slash Commands Example"
+description = "An example extension showcasing slash commands."
+version = "0.1.0"
+schema_version = 1
+authors = ["Zed Industries <hi@zed.dev>"]
+repository = "https://github.com/zed-industries/zed"
+
+[slash_commands.echo]
+description = "echoes the provided input"
+requires_argument = true
+
+[slash_commands.pick-one]
+description = "pick one of three options"
+requires_argument = true
@@ -0,0 +1,90 @@
+use zed_extension_api::{
+ self as zed, SlashCommand, SlashCommandArgumentCompletion, SlashCommandOutput,
+ SlashCommandOutputSection, Worktree,
+};
+
+struct SlashCommandsExampleExtension;
+
+impl zed::Extension for SlashCommandsExampleExtension {
+ fn new() -> Self {
+ SlashCommandsExampleExtension
+ }
+
+ fn complete_slash_command_argument(
+ &self,
+ command: SlashCommand,
+ _args: Vec<String>,
+ ) -> Result<Vec<zed_extension_api::SlashCommandArgumentCompletion>, String> {
+ match command.name.as_str() {
+ "echo" => Ok(vec![]),
+ "pick-one" => Ok(vec![
+ SlashCommandArgumentCompletion {
+ label: "Option One".to_string(),
+ new_text: "option-1".to_string(),
+ run_command: true,
+ },
+ SlashCommandArgumentCompletion {
+ label: "Option Two".to_string(),
+ new_text: "option-2".to_string(),
+ run_command: true,
+ },
+ SlashCommandArgumentCompletion {
+ label: "Option Three".to_string(),
+ new_text: "option-3".to_string(),
+ run_command: true,
+ },
+ ]),
+ command => Err(format!("unknown slash command: \"{command}\"")),
+ }
+ }
+
+ fn run_slash_command(
+ &self,
+ command: SlashCommand,
+ args: Vec<String>,
+ _worktree: Option<&Worktree>,
+ ) -> Result<SlashCommandOutput, String> {
+ match command.name.as_str() {
+ "echo" => {
+ if args.is_empty() {
+ return Err("nothing to echo".to_string());
+ }
+
+ let text = args.join(" ");
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: (0..text.len()).into(),
+ label: "Echo".to_string(),
+ }],
+ text,
+ })
+ }
+ "pick-one" => {
+ let Some(selection) = args.first() else {
+ return Err("no option selected".to_string());
+ };
+
+ match selection.as_str() {
+ "option-1" | "option-2" | "option-3" => {}
+ invalid_option => {
+ return Err(format!("{invalid_option} is not a valid option"));
+ }
+ }
+
+ let text = format!("You chose {selection}.");
+
+ Ok(SlashCommandOutput {
+ sections: vec![SlashCommandOutputSection {
+ range: (0..text.len()).into(),
+ label: format!("Pick One: {selection}"),
+ }],
+ text,
+ })
+ }
+ command => Err(format!("unknown slash command: \"{command}\"")),
+ }
+ }
+}
+
+zed::register_extension!(SlashCommandsExampleExtension);
@@ -1,6 +1,6 @@
[package]
name = "zed_zig"
-version = "0.1.5"
+version = "0.2.0"
edition = "2021"
publish = false
license = "Apache-2.0"
@@ -13,4 +13,4 @@ path = "src/zig.rs"
crate-type = ["cdylib"]
[dependencies]
-zed_extension_api = "0.0.6"
+zed_extension_api = "0.1.0"
@@ -1,7 +1,7 @@
id = "zig"
name = "Zig"
description = "Zig support."
-version = "0.1.5"
+version = "0.2.0"
schema_version = 1
authors = ["Allan Calix <contact@acx.dev>"]
repository = "https://github.com/zed-industries/zed"
@@ -61,21 +61,11 @@ impl ZigExtension {
&language_server_id,
&zed::LanguageServerInstallationStatus::CheckingForUpdate,
);
- // TODO: Once we're ready to release v0.0.7 of the Zed extension API we want to pin
- // ZLS to a specific version with `zed::github_release_by_tag_name`.
// We're pinning ZLS to a release that has `.tar.gz` assets, since the latest release does not have
// them, at time of writing.
//
// ZLS tracking issue: https://github.com/zigtools/zls/issues/1879
- // let release = zed::github_release_by_tag_name("zigtools/zls", "0.11.0")?;
-
- let release = zed::latest_github_release(
- "zigtools/zls",
- zed::GithubReleaseOptions {
- require_assets: true,
- pre_release: false,
- },
- )?;
+ let release = zed::github_release_by_tag_name("zigtools/zls", "0.11.0")?;
let asset_name = format!(
"zls-{arch}-{os}.{extension}",
@@ -1,4 +1,4 @@
-import { danger, warn } from "danger";
+import { danger, message, warn } from "danger";
const { prHygiene } = require("danger-plugin-pr-hygiene");
prHygiene({
@@ -37,14 +37,14 @@ if (!hasReleaseNotes) {
}
const ISSUE_LINK_PATTERN = new RegExp(
- "(?:https://github\\.com/[\\w-]+/[\\w-]+/issues/\\d+|#\\d+)",
+ "https://github\\.com/[\\w-]+/[\\w-]+/issues/\\d+",
"g",
);
const includesIssueUrl = ISSUE_LINK_PATTERN.test(body);
if (includesIssueUrl) {
- const matches = body.match(ISSUE_LINK_PATTERN);
+ const matches = body.match(ISSUE_LINK_PATTERN) ?? [];
const issues = matches
.map((match) =>
match
@@ -53,7 +53,7 @@ if (includesIssueUrl) {
)
.filter((issue, index, self) => self.indexOf(issue) === index);
- warn(
+ message(
[
"This PR includes links to the following GitHub Issues: " +
issues.map((issue) => `#${issue}`).join(", "),