Detailed changes
@@ -7713,6 +7713,7 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
+ "partial-json-fixer",
"project",
"proto",
"schemars",
@@ -9828,6 +9829,12 @@ dependencies = [
"windows-targets 0.52.6",
]
+[[package]]
+name = "partial-json-fixer"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
+
[[package]]
name = "password-hash"
version = "0.4.2"
@@ -480,6 +480,7 @@ num-format = "0.4.4"
ordered-float = "2.1.1"
palette = { version = "0.7.5", default-features = false, features = ["std"] }
parking_lot = "0.12.1"
+partial-json-fixer = "0.5.3"
pathdiff = "0.2"
pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
@@ -266,14 +266,6 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
}
}
-fn render_tool_use_markdown(
- text: SharedString,
- language_registry: Arc<LanguageRegistry>,
- cx: &mut App,
-) -> Entity<Markdown> {
- cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx))
-}
-
fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
@@ -867,21 +859,34 @@ impl ActiveThread {
tool_output: SharedString,
cx: &mut Context<Self>,
) {
- let rendered = RenderedToolUse {
- label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx),
- input: render_tool_use_markdown(
- format!(
- "```json\n{}\n```",
- serde_json::to_string_pretty(tool_input).unwrap_or_default()
- )
- .into(),
- self.language_registry.clone(),
- cx,
- ),
- output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx),
- };
- self.rendered_tool_uses
- .insert(tool_use_id.clone(), rendered);
+ let rendered = self
+ .rendered_tool_uses
+ .entry(tool_use_id.clone())
+ .or_insert_with(|| RenderedToolUse {
+ label: cx.new(|cx| {
+ Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+ }),
+ input: cx.new(|cx| {
+ Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+ }),
+ output: cx.new(|cx| {
+ Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+ }),
+ });
+
+ rendered.label.update(cx, |this, cx| {
+ this.replace(tool_label, cx);
+ });
+ rendered.input.update(cx, |this, cx| {
+ let input = format!(
+ "```json\n{}\n```",
+ serde_json::to_string_pretty(tool_input).unwrap_or_default()
+ );
+ this.replace(input, cx);
+ });
+ rendered.output.update(cx, |this, cx| {
+ this.replace(tool_output, cx);
+ });
}
fn handle_thread_event(
@@ -974,6 +979,19 @@ impl ActiveThread {
);
}
}
+ ThreadEvent::StreamedToolUse {
+ tool_use_id,
+ ui_text,
+ input,
+ } => {
+ self.render_tool_use_markdown(
+ tool_use_id.clone(),
+ ui_text.clone(),
+ input,
+ "".into(),
+ cx,
+ );
+ }
ThreadEvent::ToolFinished {
pending_tool_use, ..
} => {
@@ -2478,13 +2496,15 @@ impl ActiveThread {
let edit_tools = tool_use.needs_confirmation;
let status_icons = div().child(match &tool_use.status {
- ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => {
+ ToolUseStatus::NeedsConfirmation => {
let icon = Icon::new(IconName::Warning)
.color(Color::Warning)
.size(IconSize::Small);
icon.into_any_element()
}
- ToolUseStatus::Running => {
+ ToolUseStatus::Pending
+ | ToolUseStatus::InputStillStreaming
+ | ToolUseStatus::Running => {
let icon = Icon::new(IconName::ArrowCircle)
.color(Color::Accent)
.size(IconSize::Small);
@@ -2570,7 +2590,7 @@ impl ActiveThread {
}),
)),
),
- ToolUseStatus::Running => container.child(
+ ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
results_content_container().child(
h_flex()
.gap_1()
@@ -1293,12 +1293,27 @@ impl Thread {
thread.insert_message(Role::Assistant, vec![], cx)
});
- thread.tool_use.request_tool_use(
+ let tool_use_id = tool_use.id.clone();
+ let streamed_input = if tool_use.is_input_complete {
+ None
+ } else {
+ Some((&tool_use.input).clone())
+ };
+
+ let ui_text = thread.tool_use.request_tool_use(
last_assistant_message_id,
tool_use,
tool_use_metadata.clone(),
cx,
);
+
+ if let Some(input) = streamed_input {
+ cx.emit(ThreadEvent::StreamedToolUse {
+ tool_use_id,
+ ui_text,
+ input,
+ });
+ }
}
}
@@ -2189,6 +2204,11 @@ pub enum ThreadEvent {
StreamedCompletion,
StreamedAssistantText(MessageId, String),
StreamedAssistantThinking(MessageId, String),
+ StreamedToolUse {
+ tool_use_id: LanguageModelToolUseId,
+ ui_text: Arc<str>,
+ input: serde_json::Value,
+ },
Stopped(Result<StopReason, Arc<anyhow::Error>>),
MessageAdded(MessageId),
MessageEdited(MessageId),
@@ -75,6 +75,7 @@ impl ToolUseState {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
input: tool_use.input.clone(),
+ is_input_complete: true,
})
.collect::<Vec<_>>();
@@ -176,6 +177,9 @@ impl ToolUseState {
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
+ PendingToolUseStatus::InputStillStreaming => {
+ ToolUseStatus::InputStillStreaming
+ }
}
} else {
ToolUseStatus::Pending
@@ -192,7 +196,12 @@ impl ToolUseState {
tool_uses.push(ToolUse {
id: tool_use.id.clone(),
name: tool_use.name.clone().into(),
- ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
+ ui_text: self.tool_ui_label(
+ &tool_use.name,
+ &tool_use.input,
+ tool_use.is_input_complete,
+ cx,
+ ),
input: tool_use.input.clone(),
status,
icon,
@@ -207,10 +216,15 @@ impl ToolUseState {
&self,
tool_name: &str,
input: &serde_json::Value,
+ is_input_complete: bool,
cx: &App,
) -> SharedString {
if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
- tool.ui_text(input).into()
+ if is_input_complete {
+ tool.ui_text(input).into()
+ } else {
+ tool.still_streaming_ui_text(input).into()
+ }
} else {
format!("Unknown tool {tool_name:?}").into()
}
@@ -258,22 +272,50 @@ impl ToolUseState {
tool_use: LanguageModelToolUse,
metadata: ToolUseMetadata,
cx: &App,
- ) {
- self.tool_uses_by_assistant_message
+ ) -> Arc<str> {
+ let tool_uses = self
+ .tool_uses_by_assistant_message
.entry(assistant_message_id)
- .or_default()
- .push(tool_use.clone());
+ .or_default();
- self.tool_use_metadata_by_id
- .insert(tool_use.id.clone(), metadata);
+ let mut existing_tool_use_found = false;
- // The tool use is being requested by the Assistant, so we want to
- // attach the tool results to the next user message.
- let next_user_message_id = MessageId(assistant_message_id.0 + 1);
- self.tool_uses_by_user_message
- .entry(next_user_message_id)
- .or_default()
- .push(tool_use.id.clone());
+ for existing_tool_use in tool_uses.iter_mut() {
+ if existing_tool_use.id == tool_use.id {
+ *existing_tool_use = tool_use.clone();
+ existing_tool_use_found = true;
+ }
+ }
+
+ if !existing_tool_use_found {
+ tool_uses.push(tool_use.clone());
+ }
+
+ let status = if tool_use.is_input_complete {
+ self.tool_use_metadata_by_id
+ .insert(tool_use.id.clone(), metadata);
+
+ // The tool use is being requested by the Assistant, so we want to
+ // attach the tool results to the next user message.
+ let next_user_message_id = MessageId(assistant_message_id.0 + 1);
+ self.tool_uses_by_user_message
+ .entry(next_user_message_id)
+ .or_default()
+ .push(tool_use.id.clone());
+
+ PendingToolUseStatus::Idle
+ } else {
+ PendingToolUseStatus::InputStillStreaming
+ };
+
+ let ui_text: Arc<str> = self
+ .tool_ui_label(
+ &tool_use.name,
+ &tool_use.input,
+ tool_use.is_input_complete,
+ cx,
+ )
+ .into();
self.pending_tool_uses_by_id.insert(
tool_use.id.clone(),
@@ -281,13 +323,13 @@ impl ToolUseState {
assistant_message_id,
id: tool_use.id,
name: tool_use.name.clone(),
- ui_text: self
- .tool_ui_label(&tool_use.name, &tool_use.input, cx)
- .into(),
+ ui_text: ui_text.clone(),
input: tool_use.input,
- status: PendingToolUseStatus::Idle,
+ status,
},
);
+
+ ui_text
}
pub fn run_pending_tool(
@@ -497,6 +539,7 @@ pub struct Confirmation {
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
+ InputStillStreaming,
Idle,
NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
@@ -30,6 +30,7 @@ pub fn init(cx: &mut App) {
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
+ InputStillStreaming,
NeedsConfirmation,
Pending,
Running,
@@ -41,6 +42,7 @@ impl ToolUseStatus {
pub fn text(&self) -> SharedString {
match self {
ToolUseStatus::NeedsConfirmation => "".into(),
+ ToolUseStatus::InputStillStreaming => "".into(),
ToolUseStatus::Pending => "".into(),
ToolUseStatus::Running => "".into(),
ToolUseStatus::Finished(out) => out.clone(),
@@ -148,6 +150,12 @@ pub trait Tool: 'static + Send + Sync {
/// Returns markdown to be displayed in the UI for this tool.
fn ui_text(&self, input: &serde_json::Value) -> String;
+ /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
+ /// (so information may be missing).
+ fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+ self.ui_text(input)
+ }
+
/// Runs the tool with the provided input.
fn run(
self: Arc<Self>,
@@ -33,8 +33,18 @@ pub struct CreateFileToolInput {
pub contents: String,
}
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+struct PartialInput {
+ #[serde(default)]
+ path: String,
+ #[serde(default)]
+ contents: String,
+}
+
pub struct CreateFileTool;
+const DEFAULT_UI_TEXT: &str = "Create file";
+
impl Tool for CreateFileTool {
fn name(&self) -> String {
"create_file".into()
@@ -62,7 +72,14 @@ impl Tool for CreateFileTool {
let path = MarkdownString::inline_code(&input.path);
format!("Create file {path}")
}
- Err(_) => "Create file".to_string(),
+ Err(_) => DEFAULT_UI_TEXT.to_string(),
+ }
+ }
+
+ fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+ match serde_json::from_value::<PartialInput>(input.clone()).ok() {
+ Some(input) if !input.path.is_empty() => input.path,
+ _ => DEFAULT_UI_TEXT.to_string(),
}
}
@@ -111,3 +128,60 @@ impl Tool for CreateFileTool {
.into()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use serde_json::json;
+
+ #[test]
+ fn still_streaming_ui_text_with_path() {
+ let tool = CreateFileTool;
+ let input = json!({
+ "path": "src/main.rs",
+ "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
+ }
+
+ #[test]
+ fn still_streaming_ui_text_without_path() {
+ let tool = CreateFileTool;
+ let input = json!({
+ "path": "",
+ "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+ }
+
+ #[test]
+ fn still_streaming_ui_text_with_null() {
+ let tool = CreateFileTool;
+ let input = serde_json::Value::Null;
+
+ assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+ }
+
+ #[test]
+ fn ui_text_with_valid_input() {
+ let tool = CreateFileTool;
+ let input = json!({
+ "path": "src/main.rs",
+ "contents": "fn main() {\n println!(\"Hello, world!\");\n}"
+ });
+
+ assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
+ }
+
+ #[test]
+ fn ui_text_with_invalid_input() {
+ let tool = CreateFileTool;
+ let input = json!({
+ "invalid": "field"
+ });
+
+ assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
+ }
+}
@@ -47,8 +47,22 @@ pub struct EditFileToolInput {
pub new_string: String,
}
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+struct PartialInput {
+ #[serde(default)]
+ path: String,
+ #[serde(default)]
+ display_description: String,
+ #[serde(default)]
+ old_string: String,
+ #[serde(default)]
+ new_string: String,
+}
+
pub struct EditFileTool;
+const DEFAULT_UI_TEXT: &str = "Edit file";
+
impl Tool for EditFileTool {
fn name(&self) -> String {
"edit_file".into()
@@ -77,6 +91,22 @@ impl Tool for EditFileTool {
}
}
+ fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+ if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
+ let description = input.display_description.trim();
+ if !description.is_empty() {
+ return description.to_string();
+ }
+
+ let path = input.path.trim();
+ if !path.is_empty() {
+ return path.to_string();
+ }
+ }
+
+ DEFAULT_UI_TEXT.to_string()
+ }
+
fn run(
self: Arc<Self>,
input: serde_json::Value,
@@ -181,3 +211,69 @@ impl Tool for EditFileTool {
}).into()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use serde_json::json;
+
+ #[test]
+ fn still_streaming_ui_text_with_path() {
+ let tool = EditFileTool;
+ let input = json!({
+ "path": "src/main.rs",
+ "display_description": "",
+ "old_string": "old code",
+ "new_string": "new code"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
+ }
+
+ #[test]
+ fn still_streaming_ui_text_with_description() {
+ let tool = EditFileTool;
+ let input = json!({
+ "path": "",
+ "display_description": "Fix error handling",
+ "old_string": "old code",
+ "new_string": "new code"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
+ }
+
+ #[test]
+ fn still_streaming_ui_text_with_path_and_description() {
+ let tool = EditFileTool;
+ let input = json!({
+ "path": "src/main.rs",
+ "display_description": "Fix error handling",
+ "old_string": "old code",
+ "new_string": "new code"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
+ }
+
+ #[test]
+ fn still_streaming_ui_text_no_path_or_description() {
+ let tool = EditFileTool;
+ let input = json!({
+ "path": "",
+ "display_description": "",
+ "old_string": "old code",
+ "new_string": "new code"
+ });
+
+ assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+ }
+
+ #[test]
+ fn still_streaming_ui_text_with_null() {
+ let tool = EditFileTool;
+ let input = serde_json::Value::Null;
+
+ assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+ }
+}
@@ -426,6 +426,7 @@ impl Example {
ThreadEvent::ToolConfirmationNeeded => {
panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
},
+ ThreadEvent::StreamedToolUse { .. } |
ThreadEvent::StreamedCompletion |
ThreadEvent::MessageAdded(_) |
ThreadEvent::MessageEdited(_) |
@@ -187,6 +187,7 @@ pub struct LanguageModelToolUse {
pub id: LanguageModelToolUseId,
pub name: Arc<str>,
pub input: serde_json::Value,
+ pub is_input_complete: bool,
}
pub struct LanguageModelTextStream {
@@ -38,6 +38,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
+partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
schemars.workspace = true
@@ -713,6 +713,35 @@ pub fn map_to_language_model_completion_events(
ContentDelta::InputJsonDelta { partial_json } => {
if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
tool_use.input_json.push_str(&partial_json);
+
+ return Some((
+ vec![maybe!({
+ Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_use.id.clone().into(),
+ name: tool_use.name.clone().into(),
+ is_input_complete: false,
+ input: if tool_use.input_json.is_empty() {
+ serde_json::Value::Object(
+ serde_json::Map::default(),
+ )
+ } else {
+ serde_json::Value::from_str(
+ // Convert invalid (incomplete) JSON into
+ // JSON that serde will accept, e.g. by closing
+ // unclosed delimiters. This way, we can update
+ // the UI with whatever has been streamed back so far.
+ &partial_json_fixer::fix_json(
+ &tool_use.input_json,
+ ),
+ )
+ .map_err(|err| anyhow!(err))?
+ },
+ },
+ ))
+ })],
+ state,
+ ));
}
}
},
@@ -724,6 +753,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
+ is_input_complete: true,
input: if tool_use.input_json.is_empty() {
serde_json::Value::Object(
serde_json::Map::default(),
@@ -893,6 +893,7 @@ pub fn map_to_language_model_completion_events(
let tool_use_event = LanguageModelToolUse {
id: tool_use.id.into(),
name: tool_use.name.into(),
+ is_input_complete: true,
input: if tool_use.input_json.is_empty() {
Value::Null
} else {
@@ -367,6 +367,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
+ is_input_complete: true,
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
@@ -529,6 +529,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id,
name,
+ is_input_complete: true,
input: function_call_part.function_call.args,
},
)));
@@ -490,6 +490,7 @@ pub fn map_to_language_model_completion_events(
LanguageModelToolUse {
id: tool_call.id.into(),
name: tool_call.name.as_str().into(),
+ is_input_complete: true,
input: serde_json::Value::from_str(
&tool_call.arguments,
)?,
@@ -192,6 +192,11 @@ impl Markdown {
self.parse(cx);
}
+ pub fn replace(&mut self, source: impl Into<SharedString>, cx: &mut Context<Self>) {
+ self.source = source.into();
+ self.parse(cx);
+ }
+
pub fn reset(&mut self, source: SharedString, cx: &mut Context<Self>) {
if source == self.source() {
return;