Detailed changes
@@ -6,19 +6,14 @@ mod saved_conversation_picker;
mod tools;
pub mod ui;
-use crate::saved_conversation::{SavedConversation, SavedMessage, SavedMessageRole};
-use crate::saved_conversation_picker::SavedConversationPicker;
-use crate::{
- attachments::ActiveEditorAttachmentTool,
- tools::{CreateBufferTool, ProjectIndexTool},
- ui::UserOrAssistant,
-};
+use crate::ui::UserOrAssistant;
use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
use anyhow::{Context, Result};
use assistant_tooling::{
tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
UserAttachment,
};
+use attachments::ActiveEditorAttachmentTool;
use client::{proto, Client, UserStore};
use collections::HashMap;
use completion_provider::*;
@@ -33,11 +28,13 @@ use gpui::{
use language::{language_settings::SoftWrap, LanguageRegistry};
use open_ai::{FunctionContent, ToolCall, ToolCallContent};
use rich_text::RichText;
+use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
+use saved_conversation_picker::SavedConversationPicker;
use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::sync::Arc;
-use tools::AnnotationTool;
+use tools::{AnnotationTool, CreateBufferTool, ProjectIndexTool};
use ui::{ActiveFileButton, Composer, ProjectIndexButton};
use util::paths::CONVERSATIONS_DIR;
use util::{maybe, paths::EMBEDDINGS_DIR, ResultExt};
@@ -506,13 +503,11 @@ impl AssistantChat {
while let Some(delta) = stream.next().await {
let delta = delta?;
this.update(cx, |this, cx| {
- if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
- messages,
- ..
- })) = this.messages.last_mut()
+ if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
+ this.messages.last_mut()
{
if messages.is_empty() {
- messages.push(AssistantMessage {
+ messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
})
@@ -563,7 +558,7 @@ impl AssistantChat {
let mut tool_tasks = Vec::new();
this.update(cx, |this, cx| {
- if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
+ if let Some(ChatMessage::Assistant(AssistantMessage {
error: message_error,
messages,
..
@@ -592,7 +587,7 @@ impl AssistantChat {
let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
this.update(cx, |this, cx| {
- if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
+ if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
this.messages.last_mut()
{
if let Some(current_message) = messages.last_mut() {
@@ -608,19 +603,19 @@ impl AssistantChat {
fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
// If the last message is a grouped assistant message, add to the grouped message
- if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
+ if let Some(ChatMessage::Assistant(AssistantMessage { messages, .. })) =
self.messages.last_mut()
{
- messages.push(AssistantMessage {
+ messages.push(AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
});
return;
}
- let message = ChatMessage::Assistant(GroupedAssistantMessage {
+ let message = ChatMessage::Assistant(AssistantMessage {
id: self.next_message_id.post_inc(),
- messages: vec![AssistantMessage {
+ messages: vec![AssistantMessagePart {
body: RichText::default(),
tool_calls: Vec::new(),
}],
@@ -669,40 +664,30 @@ impl AssistantChat {
*entry = !*entry;
}
- fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
- let messages = self
- .messages
- .drain(..)
- .map(|message| {
- let text = match &message {
- ChatMessage::User(message) => message.body.read(cx).text(cx),
- ChatMessage::Assistant(message) => message
- .messages
- .iter()
- .map(|message| message.body.text.to_string())
- .collect::<Vec<_>>()
- .join("\n\n"),
- };
-
- SavedMessage {
- id: message.id(),
- role: match message {
- ChatMessage::User(_) => SavedMessageRole::User,
- ChatMessage::Assistant(_) => SavedMessageRole::Assistant,
- },
- text,
- }
- })
- .collect::<Vec<_>>();
-
- // Reset the chat for the new conversation.
+ fn reset(&mut self) {
+ self.messages.clear();
self.list_state.reset(0);
self.editing_message.take();
self.collapsed_messages.clear();
+ }
+
+ fn new_conversation(&mut self, cx: &mut ViewContext<Self>) {
+ let messages = std::mem::take(&mut self.messages)
+ .into_iter()
+ .map(|message| self.serialize_message(message, cx))
+ .collect::<Vec<_>>();
+
+ self.reset();
let title = messages
.first()
- .map(|message| message.text.clone())
+ .map(|message| match message {
+ SavedChatMessage::User { body, .. } => body.clone(),
+ SavedChatMessage::Assistant { messages, .. } => messages
+ .first()
+ .map(|message| message.body.to_string())
+ .unwrap_or_default(),
+ })
.unwrap_or_else(|| "A conversation with the assistant.".to_string());
let saved_conversation = SavedConversation {
@@ -836,7 +821,7 @@ impl AssistantChat {
}
})
.into_any(),
- ChatMessage::Assistant(GroupedAssistantMessage {
+ ChatMessage::Assistant(AssistantMessage {
id,
messages,
error,
@@ -917,7 +902,7 @@ impl AssistantChat {
content: body.read(cx).text(cx),
});
}
- ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
+ ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
for message in messages {
let body = message.body.clone();
@@ -971,6 +956,43 @@ impl AssistantChat {
Ok(completion_messages)
})
}
+
+ fn serialize_message(
+ &self,
+ message: ChatMessage,
+ cx: &mut ViewContext<AssistantChat>,
+ ) -> SavedChatMessage {
+ match message {
+ ChatMessage::User(message) => SavedChatMessage::User {
+ id: message.id,
+ body: message.body.read(cx).text(cx),
+ attachments: message
+ .attachments
+ .iter()
+ .map(|attachment| {
+ self.attachment_registry
+ .serialize_user_attachment(attachment)
+ })
+ .collect(),
+ },
+ ChatMessage::Assistant(message) => SavedChatMessage::Assistant {
+ id: message.id,
+ error: message.error,
+ messages: message
+ .messages
+ .iter()
+ .map(|message| SavedAssistantMessagePart {
+ body: message.body.text.clone(),
+ tool_calls: message
+ .tool_calls
+ .iter()
+ .map(|tool_call| self.tool_registry.serialize_tool_call(tool_call))
+ .collect(),
+ })
+ .collect(),
+ },
+ }
+ }
}
impl Render for AssistantChat {
@@ -1053,17 +1075,10 @@ impl MessageId {
enum ChatMessage {
User(UserMessage),
- Assistant(GroupedAssistantMessage),
+ Assistant(AssistantMessage),
}
impl ChatMessage {
- pub fn id(&self) -> MessageId {
- match self {
- ChatMessage::User(message) => message.id,
- ChatMessage::Assistant(message) => message.id,
- }
- }
-
fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
match self {
ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
@@ -1073,18 +1088,18 @@ impl ChatMessage {
}
struct UserMessage {
- id: MessageId,
- body: View<Editor>,
- attachments: Vec<UserAttachment>,
+ pub id: MessageId,
+ pub body: View<Editor>,
+ pub attachments: Vec<UserAttachment>,
}
-struct AssistantMessage {
- body: RichText,
- tool_calls: Vec<ToolFunctionCall>,
+struct AssistantMessagePart {
+ pub body: RichText,
+ pub tool_calls: Vec<ToolFunctionCall>,
}
-struct GroupedAssistantMessage {
- id: MessageId,
- messages: Vec<AssistantMessage>,
- error: Option<SharedString>,
+struct AssistantMessage {
+ pub id: MessageId,
+ pub messages: Vec<AssistantMessagePart>,
+ pub error: Option<SharedString>,
}
@@ -1,64 +1,68 @@
+use std::{path::PathBuf, sync::Arc};
+
use anyhow::{anyhow, Result};
use assistant_tooling::{LanguageModelAttachment, ProjectContext, ToolOutput};
use editor::Editor;
use gpui::{Render, Task, View, WeakModel, WeakView};
use language::Buffer;
use project::ProjectPath;
+use serde::{Deserialize, Serialize};
use ui::{prelude::*, ButtonLike, Tooltip, WindowContext};
use util::maybe;
use workspace::Workspace;
+#[derive(Serialize, Deserialize)]
pub struct ActiveEditorAttachment {
- buffer: WeakModel<Buffer>,
- path: Option<ProjectPath>,
+ #[serde(skip)]
+ buffer: Option<WeakModel<Buffer>>,
+ path: Option<PathBuf>,
}
pub struct FileAttachmentView {
- output: Result<ActiveEditorAttachment>,
+ project_path: Option<ProjectPath>,
+ buffer: Option<WeakModel<Buffer>>,
+ error: Option<anyhow::Error>,
}
impl Render for FileAttachmentView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- match &self.output {
- Ok(attachment) => {
- let filename: SharedString = attachment
- .path
- .as_ref()
- .and_then(|p| p.path.file_name()?.to_str())
- .unwrap_or("Untitled")
- .to_string()
- .into();
-
- // todo!(): make the button link to the actual file to open
- ButtonLike::new("file-attachment")
- .child(
- h_flex()
- .gap_1()
- .bg(cx.theme().colors().editor_background)
- .rounded_md()
- .child(ui::Icon::new(IconName::File))
- .child(filename.clone()),
- )
- .tooltip({
- move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx)
- })
- .into_any_element()
- }
- Err(err) => div().child(err.to_string()).into_any_element(),
+ if let Some(error) = &self.error {
+ return div().child(error.to_string()).into_any_element();
}
+
+ let filename: SharedString = self
+ .project_path
+ .as_ref()
+ .and_then(|p| p.path.file_name()?.to_str())
+ .unwrap_or("Untitled")
+ .to_string()
+ .into();
+
+ ButtonLike::new("file-attachment")
+ .child(
+ h_flex()
+ .gap_1()
+ .bg(cx.theme().colors().editor_background)
+ .rounded_md()
+ .child(ui::Icon::new(IconName::File))
+ .child(filename.clone()),
+ )
+ .tooltip(move |cx| Tooltip::with_meta("File Attached", None, filename.clone(), cx))
+ .into_any_element()
}
}
impl ToolOutput for FileAttachmentView {
fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String {
- if let Ok(result) = &self.output {
- if let Some(path) = &result.path {
- project.add_file(path.clone());
- return format!("current file: {}", path.path.display());
- } else if let Some(buffer) = result.buffer.upgrade() {
- return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
- }
+ if let Some(path) = &self.project_path {
+ project.add_file(path.clone());
+ return format!("current file: {}", path.path.display());
+ }
+
+ if let Some(buffer) = self.buffer.as_ref().and_then(|buffer| buffer.upgrade()) {
+ return format!("current untitled buffer text:\n{}", buffer.read(cx).text());
}
+
String::new()
}
}
@@ -77,6 +81,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
type Output = ActiveEditorAttachment;
type View = FileAttachmentView;
+ fn name(&self) -> Arc<str> {
+ "active-editor-attachment".into()
+ }
+
fn run(&self, cx: &mut WindowContext) -> Task<Result<ActiveEditorAttachment>> {
Task::ready(maybe!({
let active_buffer = self
@@ -91,13 +99,10 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
let buffer = active_buffer.read(cx);
if let Some(buffer) = buffer.as_singleton() {
- let path =
- project::File::from_dyn(buffer.read(cx).file()).map(|file| ProjectPath {
- worktree_id: file.worktree_id(cx),
- path: file.path.clone(),
- });
+ let path = project::File::from_dyn(buffer.read(cx).file())
+ .and_then(|file| file.worktree.read(cx).absolutize(&file.path).ok());
return Ok(ActiveEditorAttachment {
- buffer: buffer.downgrade(),
+ buffer: Some(buffer.downgrade()),
path,
});
} else {
@@ -106,7 +111,34 @@ impl LanguageModelAttachment for ActiveEditorAttachmentTool {
}))
}
- fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View> {
- cx.new_view(|_cx| FileAttachmentView { output })
+ fn view(
+ &self,
+ output: Result<ActiveEditorAttachment>,
+ cx: &mut WindowContext,
+ ) -> View<Self::View> {
+ let error;
+ let project_path;
+ let buffer;
+ match output {
+ Ok(output) => {
+ error = None;
+ let workspace = self.workspace.upgrade().unwrap();
+ let project = workspace.read(cx).project();
+ project_path = output
+ .path
+ .and_then(|path| project.read(cx).project_path_for_absolute_path(&path, cx));
+ buffer = output.buffer;
+ }
+ Err(err) => {
+ error = Some(err);
+ buffer = None;
+ project_path = None;
+ }
+ }
+ cx.new_view(|_cx| FileAttachmentView {
+ project_path,
+ buffer,
+ error,
+ })
}
}
@@ -1,3 +1,5 @@
+use assistant_tooling::{SavedToolFunctionCall, SavedUserAttachment};
+use gpui::SharedString;
use serde::{Deserialize, Serialize};
use crate::MessageId;
@@ -8,21 +10,27 @@ pub struct SavedConversation {
pub version: String,
/// The title of the conversation, generated by the Assistant.
pub title: String,
- pub messages: Vec<SavedMessage>,
+ pub messages: Vec<SavedChatMessage>,
}
#[derive(Serialize, Deserialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum SavedMessageRole {
- User,
- Assistant,
+pub enum SavedChatMessage {
+ User {
+ id: MessageId,
+ body: String,
+ attachments: Vec<SavedUserAttachment>,
+ },
+ Assistant {
+ id: MessageId,
+ messages: Vec<SavedAssistantMessagePart>,
+ error: Option<SharedString>,
+ },
}
#[derive(Serialize, Deserialize)]
-pub struct SavedMessage {
- pub id: MessageId,
- pub role: SavedMessageRole,
- pub text: String,
+pub struct SavedAssistantMessagePart {
+ pub body: SharedString,
+ pub tool_calls: Vec<SavedToolFunctionCall>,
}
/// Returns a list of placeholder conversations for mocking the UI.
@@ -6,7 +6,7 @@ use editor::{
};
use gpui::{prelude::*, AnyElement, Model, Task, View, WeakView};
use language::ToPoint;
-use project::{Project, ProjectPath};
+use project::{search::SearchQuery, Project, ProjectPath};
use schemars::JsonSchema;
use serde::Deserialize;
use std::path::Path;
@@ -29,17 +29,18 @@ impl AnnotationTool {
pub struct AnnotationInput {
/// Name for this set of annotations
title: String,
- annotations: Vec<Annotation>,
+ /// Excerpts from the file to show to the user.
+ excerpts: Vec<Excerpt>,
}
#[derive(Debug, Deserialize, JsonSchema, Clone)]
-struct Annotation {
+struct Excerpt {
/// Path to the file
path: String,
- /// Name of a symbol in the code
- symbol_name: String,
- /// Text to display near the symbol definition
- text: String,
+ /// A short, distinctive string that appears in the file, used to define a location in the file.
+ text_passage: String,
+ /// Text to display above the code excerpt
+ annotation: String,
}
impl LanguageModelTool for AnnotationTool {
@@ -58,7 +59,7 @@ impl LanguageModelTool for AnnotationTool {
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
let workspace = self.workspace.clone();
let project = self.project.clone();
- let excerpts = input.annotations.clone();
+ let excerpts = input.excerpts.clone();
let title = input.title.clone();
let worktree_id = project.update(cx, |project, cx| {
@@ -74,15 +75,16 @@ impl LanguageModelTool for AnnotationTool {
};
let buffer_tasks = project.update(cx, |project, cx| {
- let excerpts = excerpts.clone();
excerpts
.iter()
.map(|excerpt| {
- let project_path = ProjectPath {
- worktree_id,
- path: Path::new(&excerpt.path).into(),
- };
- project.open_buffer(project_path.clone(), cx)
+ project.open_buffer(
+ ProjectPath {
+ worktree_id,
+ path: Path::new(&excerpt.path).into(),
+ },
+ cx,
+ )
})
.collect::<Vec<_>>()
});
@@ -99,39 +101,43 @@ impl LanguageModelTool for AnnotationTool {
for (excerpt, buffer) in excerpts.iter().zip(buffers.iter()) {
let snapshot = buffer.update(&mut cx, |buffer, _cx| buffer.snapshot())?;
- if let Some(outline) = snapshot.outline(None) {
- let matches = outline
- .search(&excerpt.symbol_name, cx.background_executor().clone())
- .await;
- if let Some(mat) = matches.first() {
- let item = &outline.items[mat.candidate_id];
- let start = item.range.start.to_point(&snapshot);
- editor.update(&mut cx, |editor, cx| {
- let ranges = editor.buffer().update(cx, |multibuffer, cx| {
- multibuffer.push_excerpts_with_context_lines(
- buffer.clone(),
- vec![start..start],
- 5,
- cx,
- )
- });
- let explanation = SharedString::from(excerpt.text.clone());
- editor.insert_blocks(
- [BlockProperties {
- position: ranges[0].start,
- height: 2,
- style: BlockStyle::Fixed,
- render: Box::new(move |cx| {
- Self::render_note_block(&explanation, cx)
- }),
- disposition: BlockDisposition::Above,
- }],
- None,
- cx,
- );
- })?;
- }
- }
+ let query =
+ SearchQuery::text(&excerpt.text_passage, false, false, false, vec![], vec![])?;
+
+ let matches = query.search(&snapshot, None).await;
+ let Some(first_match) = matches.first() else {
+ log::warn!(
+ "text {:?} does not appear in '{}'",
+ excerpt.text_passage,
+ excerpt.path
+ );
+ continue;
+ };
+ let mut start = first_match.start.to_point(&snapshot);
+ start.column = 0;
+
+ editor.update(&mut cx, |editor, cx| {
+ let ranges = editor.buffer().update(cx, |multibuffer, cx| {
+ multibuffer.push_excerpts_with_context_lines(
+ buffer.clone(),
+ vec![start..start],
+ 5,
+ cx,
+ )
+ });
+ let annotation = SharedString::from(excerpt.annotation.clone());
+ editor.insert_blocks(
+ [BlockProperties {
+ position: ranges[0].start,
+ height: annotation.split('\n').count() as u8 + 1,
+ style: BlockStyle::Fixed,
+ render: Box::new(move |cx| Self::render_note_block(&annotation, cx)),
+ disposition: BlockDisposition::Above,
+ }],
+ None,
+ cx,
+ );
+ })?;
}
workspace
@@ -144,7 +150,8 @@ impl LanguageModelTool for AnnotationTool {
})
}
- fn output_view(
+ fn view(
+ &self,
_: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@@ -86,7 +86,8 @@ impl LanguageModelTool for CreateBufferTool {
})
}
- fn output_view(
+ fn view(
+ &self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@@ -1,13 +1,13 @@
-use anyhow::Result;
+use anyhow::{anyhow, Result};
use assistant_tooling::{LanguageModelTool, ToolOutput};
use collections::BTreeMap;
use gpui::{prelude::*, Model, Task};
use project::ProjectPath;
use schemars::JsonSchema;
use semantic_index::{ProjectIndex, Status};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
use serde_json::Value;
-use std::{fmt::Write as _, ops::Range};
+use std::{fmt::Write as _, ops::Range, path::Path, sync::Arc};
use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
const DEFAULT_SEARCH_LIMIT: usize = 20;
@@ -29,28 +29,24 @@ pub struct CodebaseQuery {
pub struct ProjectIndexView {
input: CodebaseQuery,
- output: Result<ProjectIndexOutput>,
+ status: Status,
+ excerpts: Result<BTreeMap<ProjectPath, Vec<Range<usize>>>>,
element_id: ElementId,
expanded_header: bool,
}
+#[derive(Serialize, Deserialize)]
pub struct ProjectIndexOutput {
status: Status,
- excerpts: BTreeMap<ProjectPath, Vec<Range<usize>>>,
+ worktrees: BTreeMap<Arc<Path>, WorktreeIndexOutput>,
}
-impl ProjectIndexView {
- fn new(input: CodebaseQuery, output: Result<ProjectIndexOutput>) -> Self {
- let element_id = ElementId::Name(nanoid::nanoid!().into());
-
- Self {
- input,
- output,
- element_id,
- expanded_header: false,
- }
- }
+#[derive(Serialize, Deserialize)]
+struct WorktreeIndexOutput {
+ excerpts: BTreeMap<Arc<Path>, Vec<Range<usize>>>,
+}
+impl ProjectIndexView {
fn toggle_header(&mut self, cx: &mut ViewContext<Self>) {
self.expanded_header = !self.expanded_header;
cx.notify();
@@ -60,18 +56,14 @@ impl ProjectIndexView {
impl Render for ProjectIndexView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let query = self.input.query.clone();
-
- let result = &self.output;
-
- let output = match result {
+ let excerpts = match &self.excerpts {
Err(err) => {
return div().child(Label::new(format!("Error: {}", err)).color(Color::Error));
}
- Ok(output) => output,
+ Ok(excerpts) => excerpts,
};
- let file_count = output.excerpts.len();
-
+ let file_count = excerpts.len();
let header = h_flex()
.gap_2()
.child(Icon::new(IconName::File))
@@ -97,16 +89,12 @@ impl Render for ProjectIndexView {
.child(Icon::new(IconName::MagnifyingGlass))
.child(Label::new(format!("`{}`", query)).color(Color::Muted)),
)
- .child(
- v_flex()
- .gap_2()
- .children(output.excerpts.keys().map(|path| {
- h_flex().gap_2().child(Icon::new(IconName::File)).child(
- Label::new(path.path.to_string_lossy().to_string())
- .color(Color::Muted),
- )
- })),
- ),
+ .child(v_flex().gap_2().children(excerpts.keys().map(|path| {
+ h_flex().gap_2().child(Icon::new(IconName::File)).child(
+ Label::new(path.path.to_string_lossy().to_string())
+ .color(Color::Muted),
+ )
+ }))),
),
)
}
@@ -118,16 +106,16 @@ impl ToolOutput for ProjectIndexView {
context: &mut assistant_tooling::ProjectContext,
_: &mut WindowContext,
) -> String {
- match &self.output {
- Ok(output) => {
+ match &self.excerpts {
+ Ok(excerpts) => {
let mut body = "found results in the following paths:\n".to_string();
- for (project_path, ranges) in &output.excerpts {
+ for (project_path, ranges) in excerpts {
context.add_excerpts(project_path.clone(), ranges);
writeln!(&mut body, "* {}", &project_path.path.display()).unwrap();
}
- if output.status != Status::Idle {
+ if self.status != Status::Idle {
body.push_str("Still indexing. Results may be incomplete.\n");
}
@@ -172,16 +160,20 @@ impl LanguageModelTool for ProjectIndexTool {
cx.update(|cx| {
let mut output = ProjectIndexOutput {
status,
- excerpts: Default::default(),
+ worktrees: Default::default(),
};
for search_result in search_results {
- let path = ProjectPath {
- worktree_id: search_result.worktree.read(cx).id(),
- path: search_result.path.clone(),
- };
-
- let excerpts_for_path = output.excerpts.entry(path).or_default();
+ let worktree_path = search_result.worktree.read(cx).abs_path();
+ let excerpts = &mut output
+ .worktrees
+ .entry(worktree_path)
+ .or_insert(WorktreeIndexOutput {
+ excerpts: Default::default(),
+ })
+ .excerpts;
+
+ let excerpts_for_path = excerpts.entry(search_result.path).or_default();
let ix = match excerpts_for_path
.binary_search_by_key(&search_result.range.start, |r| r.start)
{
@@ -195,12 +187,57 @@ impl LanguageModelTool for ProjectIndexTool {
})
}
- fn output_view(
+ fn view(
+ &self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
) -> gpui::View<Self::View> {
- cx.new_view(|_cx| ProjectIndexView::new(input, output))
+ cx.new_view(|cx| {
+ let status;
+ let excerpts;
+ match output {
+ Ok(output) => {
+ status = output.status;
+ let project_index = self.project_index.read(cx);
+ if let Some(project) = project_index.project().upgrade() {
+ let project = project.read(cx);
+ excerpts = Ok(output
+ .worktrees
+ .into_iter()
+ .filter_map(|(abs_path, output)| {
+ for worktree in project.worktrees() {
+ let worktree = worktree.read(cx);
+ if worktree.abs_path() == abs_path {
+ return Some((worktree.id(), output.excerpts));
+ }
+ }
+ None
+ })
+ .flat_map(|(worktree_id, excerpts)| {
+ excerpts.into_iter().map(move |(path, ranges)| {
+ (ProjectPath { worktree_id, path }, ranges)
+ })
+ })
+ .collect::<BTreeMap<_, _>>());
+ } else {
+ excerpts = Err(anyhow!("project was dropped"));
+ }
+ }
+ Err(err) => {
+ status = Status::Idle;
+ excerpts = Err(err);
+ }
+ };
+
+ ProjectIndexView {
+ input,
+ status,
+ excerpts,
+ element_id: ElementId::Name(nanoid::nanoid!().into()),
+ expanded_header: false,
+ }
+ })
}
fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
@@ -2,9 +2,12 @@ mod attachment_registry;
mod project_context;
mod tool_registry;
-pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
+pub use attachment_registry::{
+ AttachmentRegistry, LanguageModelAttachment, SavedUserAttachment, UserAttachment,
+};
pub use project_context::ProjectContext;
pub use tool_registry::{
- tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition,
+ tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall,
+ SavedToolFunctionCallResult, ToolFunctionCall, ToolFunctionCallResult, ToolFunctionDefinition,
ToolOutput, ToolRegistry,
};
@@ -3,6 +3,8 @@ use anyhow::{anyhow, Result};
use collections::HashMap;
use futures::future::join_all;
use gpui::{AnyView, Render, Task, View, WindowContext};
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use serde_json::value::RawValue;
use std::{
any::TypeId,
sync::{
@@ -17,24 +19,34 @@ pub struct AttachmentRegistry {
}
pub trait LanguageModelAttachment {
- type Output: 'static;
+ type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
+ fn name(&self) -> Arc<str>;
fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
-
- fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
+ fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
/// A collected attachment from running an attachment tool
pub struct UserAttachment {
pub view: AnyView,
+ name: Arc<str>,
+ serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
+#[derive(Serialize, Deserialize)]
+pub struct SavedUserAttachment {
+ name: Arc<str>,
+ serialized_output: Result<Box<RawValue>, String>,
+}
+
/// Internal representation of an attachment tool to allow us to treat them dynamically
struct RegisteredAttachment {
+ name: Arc<str>,
enabled: AtomicBool,
call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
+ deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
}
impl AttachmentRegistry {
@@ -45,24 +57,65 @@ impl AttachmentRegistry {
}
pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
- let call = Box::new(move |cx: &mut WindowContext| {
- let result = attachment.run(cx);
+ let attachment = Arc::new(attachment);
+
+ let call = Box::new({
+ let attachment = attachment.clone();
+ move |cx: &mut WindowContext| {
+ let result = attachment.run(cx);
+ let attachment = attachment.clone();
+ cx.spawn(move |mut cx| async move {
+ let result: Result<A::Output> = result.await;
+ let serialized_output =
+ result
+ .as_ref()
+ .map_err(ToString::to_string)
+ .and_then(|output| {
+ Ok(RawValue::from_string(
+ serde_json::to_string(output).map_err(|e| e.to_string())?,
+ )
+ .unwrap())
+ });
+
+ let view = cx.update(|cx| attachment.view(result, cx))?;
+
+ Ok(UserAttachment {
+ name: attachment.name(),
+ view: view.into(),
+ generate_fn: generate::<A>,
+ serialized_output,
+ })
+ })
+ }
+ });
- cx.spawn(move |mut cx| async move {
- let result: Result<A::Output> = result.await;
- let view = cx.update(|cx| A::view(result, cx))?;
+ let deserialize = Box::new({
+ let attachment = attachment.clone();
+ move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
+ let serialized_output = saved_attachment.serialized_output.clone();
+ let output = match &serialized_output {
+ Ok(serialized_output) => {
+ Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
+ }
+ Err(error) => Err(anyhow!("{error}")),
+ };
+ let view = attachment.view(output, cx).into();
Ok(UserAttachment {
- view: view.into(),
+ name: saved_attachment.name.clone(),
+ view,
+ serialized_output,
generate_fn: generate::<A>,
})
- })
+ }
});
self.registered_attachments.insert(
TypeId::of::<A>(),
RegisteredAttachment {
+ name: attachment.name(),
call,
+ deserialize,
enabled: AtomicBool::new(true),
},
);
@@ -134,6 +187,35 @@ impl AttachmentRegistry {
.collect())
})
}
+
+ pub fn serialize_user_attachment(
+ &self,
+ user_attachment: &UserAttachment,
+ ) -> SavedUserAttachment {
+ SavedUserAttachment {
+ name: user_attachment.name.clone(),
+ serialized_output: user_attachment.serialized_output.clone(),
+ }
+ }
+
+ pub fn deserialize_user_attachment(
+ &self,
+ saved_user_attachment: SavedUserAttachment,
+ cx: &mut WindowContext,
+ ) -> Result<UserAttachment> {
+ if let Some(registered_attachment) = self
+ .registered_attachments
+ .values()
+ .find(|attachment| attachment.name == saved_user_attachment.name)
+ {
+ (registered_attachment.deserialize)(&saved_user_attachment, cx)
+ } else {
+ Err(anyhow!(
+ "no attachment tool for name {}",
+ saved_user_attachment.name
+ ))
+ }
+ }
}
impl UserAttachment {
@@ -1,41 +1,60 @@
+use crate::ProjectContext;
use anyhow::{anyhow, Result};
use gpui::{
div, AnyElement, AnyView, IntoElement, ParentElement, Render, Styled, Task, View, WindowContext,
};
use schemars::{schema::RootSchema, schema_for, JsonSchema};
-use serde::Deserialize;
-use serde_json::Value;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use serde_json::{value::RawValue, Value};
use std::{
any::TypeId,
collections::HashMap,
fmt::Display,
- sync::atomic::{AtomicBool, Ordering::SeqCst},
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
};
-use crate::ProjectContext;
-
pub struct ToolRegistry {
registered_tools: HashMap<String, RegisteredTool>,
}
-#[derive(Default, Deserialize)]
+#[derive(Default)]
pub struct ToolFunctionCall {
pub id: String,
pub name: String,
pub arguments: String,
- #[serde(skip)]
pub result: Option<ToolFunctionCallResult>,
}
+#[derive(Default, Serialize, Deserialize)]
+pub struct SavedToolFunctionCall {
+ pub id: String,
+ pub name: String,
+ pub arguments: String,
+ pub result: Option<SavedToolFunctionCallResult>,
+}
+
pub enum ToolFunctionCallResult {
NoSuchTool,
ParsingFailed,
Finished {
view: AnyView,
+ serialized_output: Result<Box<RawValue>, String>,
generate_fn: fn(AnyView, &mut ProjectContext, &mut WindowContext) -> String,
},
}
+#[derive(Serialize, Deserialize)]
+pub enum SavedToolFunctionCallResult {
+ NoSuchTool,
+ ParsingFailed,
+ Finished {
+ serialized_output: Result<Box<RawValue>, String>,
+ },
+}
+
#[derive(Clone)]
pub struct ToolFunctionDefinition {
pub name: String,
@@ -46,10 +65,10 @@ pub struct ToolFunctionDefinition {
pub trait LanguageModelTool {
/// The input type that will be passed in to `execute` when the tool is called
/// by the language model.
- type Input: for<'de> Deserialize<'de> + JsonSchema;
+ type Input: DeserializeOwned + JsonSchema;
/// The output returned by executing the tool.
- type Output: 'static;
+ type Output: DeserializeOwned + Serialize + 'static;
type View: Render + ToolOutput;
@@ -80,7 +99,8 @@ pub trait LanguageModelTool {
fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
/// A view of the output of running the tool, for displaying to the user.
- fn output_view(
+ fn view(
+ &self,
input: Self::Input,
output: Result<Self::Output>,
cx: &mut WindowContext,
@@ -102,7 +122,8 @@ pub trait ToolOutput: Sized {
struct RegisteredTool {
enabled: AtomicBool,
type_id: TypeId,
- call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ execute: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
+ deserialize: Box<dyn Fn(&SavedToolFunctionCall, &mut WindowContext) -> ToolFunctionCall>,
render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
definition: ToolFunctionDefinition,
}
@@ -162,23 +183,125 @@ impl ToolRegistry {
}
}
+ pub fn serialize_tool_call(&self, call: &ToolFunctionCall) -> SavedToolFunctionCall {
+ SavedToolFunctionCall {
+ id: call.id.clone(),
+ name: call.name.clone(),
+ arguments: call.arguments.clone(),
+ result: call.result.as_ref().map(|result| match result {
+ ToolFunctionCallResult::NoSuchTool => SavedToolFunctionCallResult::NoSuchTool,
+ ToolFunctionCallResult::ParsingFailed => SavedToolFunctionCallResult::ParsingFailed,
+ ToolFunctionCallResult::Finished {
+ serialized_output, ..
+ } => SavedToolFunctionCallResult::Finished {
+ serialized_output: match serialized_output {
+ Ok(value) => Ok(value.clone()),
+ Err(e) => Err(e.to_string()),
+ },
+ },
+ }),
+ }
+ }
+
+ pub fn deserialize_tool_call(
+ &self,
+ call: &SavedToolFunctionCall,
+ cx: &mut WindowContext,
+ ) -> ToolFunctionCall {
+ if let Some(tool) = &self.registered_tools.get(&call.name) {
+ (tool.deserialize)(call, cx)
+ } else {
+ ToolFunctionCall {
+ id: call.id.clone(),
+ name: call.name.clone(),
+ arguments: call.arguments.clone(),
+ result: Some(ToolFunctionCallResult::NoSuchTool),
+ }
+ }
+ }
+
pub fn register<T: 'static + LanguageModelTool>(
&mut self,
tool: T,
_cx: &mut WindowContext,
) -> Result<()> {
let name = tool.name();
+ let tool = Arc::new(tool);
let registered_tool = RegisteredTool {
type_id: TypeId::of::<T>(),
definition: tool.definition(),
enabled: AtomicBool::new(true),
- call: Box::new(
- move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
+ deserialize: Box::new({
+ let tool = tool.clone();
+ move |tool_call: &SavedToolFunctionCall, cx: &mut WindowContext| {
+ let id = tool_call.id.clone();
let name = tool_call.name.clone();
let arguments = tool_call.arguments.clone();
+
+ let Ok(input) = serde_json::from_str::<T::Input>(&tool_call.arguments) else {
+ return ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(ToolFunctionCallResult::ParsingFailed),
+ };
+ };
+
+ let result = match &tool_call.result {
+ Some(result) => match result {
+ SavedToolFunctionCallResult::NoSuchTool => {
+ Some(ToolFunctionCallResult::NoSuchTool)
+ }
+ SavedToolFunctionCallResult::ParsingFailed => {
+ Some(ToolFunctionCallResult::ParsingFailed)
+ }
+ SavedToolFunctionCallResult::Finished { serialized_output } => {
+ let output = match serialized_output {
+ Ok(value) => {
+ match serde_json::from_str::<T::Output>(value.get()) {
+ Ok(value) => Ok(value),
+ Err(_) => {
+ return ToolFunctionCall {
+ id,
+ name: name.clone(),
+ arguments,
+ result: Some(
+ ToolFunctionCallResult::ParsingFailed,
+ ),
+ };
+ }
+ }
+ }
+ Err(e) => Err(anyhow!("{e}")),
+ };
+
+ let view = tool.view(input, output, cx).into();
+ Some(ToolFunctionCallResult::Finished {
+ serialized_output: serialized_output.clone(),
+ generate_fn: generate::<T>,
+ view,
+ })
+ }
+ },
+ None => None,
+ };
+
+ ToolFunctionCall {
+ id: tool_call.id.clone(),
+ name: name.clone(),
+ arguments: tool_call.arguments.clone(),
+ result,
+ }
+ }
+ }),
+ execute: Box::new({
+ let tool = tool.clone();
+ move |tool_call: &ToolFunctionCall, cx: &mut WindowContext| {
let id = tool_call.id.clone();
+ let name = tool_call.name.clone();
+ let arguments = tool_call.arguments.clone();
- let Ok(input) = serde_json::from_str::<T::Input>(arguments.as_str()) else {
+ let Ok(input) = serde_json::from_str::<T::Input>(&arguments) else {
return Task::ready(Ok(ToolFunctionCall {
id,
name: name.clone(),
@@ -188,23 +311,33 @@ impl ToolRegistry {
};
let result = tool.execute(&input, cx);
-
+ let tool = tool.clone();
cx.spawn(move |mut cx| async move {
- let result: Result<T::Output> = result.await;
- let view = cx.update(|cx| T::output_view(input, result, cx))?;
+ let result = result.await;
+ let serialized_output = result
+ .as_ref()
+ .map_err(ToString::to_string)
+ .and_then(|output| {
+ Ok(RawValue::from_string(
+ serde_json::to_string(output).map_err(|e| e.to_string())?,
+ )
+ .unwrap())
+ });
+ let view = cx.update(|cx| tool.view(input, result, cx))?;
Ok(ToolFunctionCall {
id,
name: name.clone(),
arguments,
result: Some(ToolFunctionCallResult::Finished {
+ serialized_output,
view: view.into(),
generate_fn: generate::<T>,
}),
})
})
- },
- ),
+ }
+ }),
render_running: render_running::<T>,
};
@@ -259,7 +392,7 @@ impl ToolRegistry {
}
};
- (tool.call)(tool_call, cx)
+ (tool.execute)(tool_call, cx)
}
}
@@ -275,9 +408,9 @@ impl ToolFunctionCallResult {
ToolFunctionCallResult::ParsingFailed => {
format!("Unable to parse arguments for {name}")
}
- ToolFunctionCallResult::Finished { generate_fn, view } => {
- (generate_fn)(view.clone(), project, cx)
- }
+ ToolFunctionCallResult::Finished {
+ generate_fn, view, ..
+ } => (generate_fn)(view.clone(), project, cx),
}
}
@@ -373,7 +506,8 @@ mod test {
Task::ready(Ok(weather))
}
- fn output_view(
+ fn view(
+ &self,
_input: Self::Input,
result: Result<Self::Output>,
cx: &mut WindowContext,
@@ -7864,6 +7864,18 @@ impl Project {
})
}
+ pub fn project_path_for_absolute_path(
+ &self,
+ abs_path: &Path,
+ cx: &AppContext,
+ ) -> Option<ProjectPath> {
+ self.find_local_worktree(abs_path, cx)
+ .map(|(worktree, relative_path)| ProjectPath {
+ worktree_id: worktree.read(cx).id(),
+ path: relative_path.into(),
+ })
+ }
+
pub fn get_workspace_root(
&self,
project_path: &ProjectPath,
@@ -250,6 +250,7 @@ impl SearchQuery {
}
}
}
+
pub async fn search(
&self,
buffer: &BufferSnapshot,
@@ -450,7 +450,7 @@ pub struct WorktreeSearchResult {
pub score: f32,
}
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
Idle,
Loading,