1use crate::{ProjectContext, ToolOutput};
2use anyhow::{anyhow, Result};
3use collections::HashMap;
4use futures::future::join_all;
5use gpui::{AnyView, Render, Task, View, WindowContext};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::value::RawValue;
8use std::{
9 any::TypeId,
10 sync::{
11 atomic::{AtomicBool, Ordering::SeqCst},
12 Arc,
13 },
14};
15use util::ResultExt as _;
16
17pub struct AttachmentRegistry {
18 registered_attachments: HashMap<TypeId, RegisteredAttachment>,
19}
20
21pub trait LanguageModelAttachment {
22 type Output: DeserializeOwned + Serialize + 'static;
23 type View: Render + ToolOutput;
24
25 fn name(&self) -> Arc<str>;
26 fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
27 fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
28}
29
30/// A collected attachment from running an attachment tool
31pub struct UserAttachment {
32 pub view: AnyView,
33 name: Arc<str>,
34 serialized_output: Result<Box<RawValue>, String>,
35 generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
36}
37
/// Persistable form of a [`UserAttachment`]: the tool's name plus its serialized output.
///
/// Restored by `AttachmentRegistry::deserialize_user_attachment`, which looks up
/// the registered tool whose name equals `name`.
#[derive(Serialize, Deserialize)]
pub struct SavedUserAttachment {
    /// Name of the attachment tool that produced this attachment.
    name: Arc<str>,
    /// Raw JSON of the tool's output, or the error message the tool produced.
    serialized_output: Result<Box<RawValue>, String>,
}
43
44/// Internal representation of an attachment tool to allow us to treat them dynamically
45struct RegisteredAttachment {
46 name: Arc<str>,
47 enabled: AtomicBool,
48 call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
49 deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
50}
51
52impl AttachmentRegistry {
53 pub fn new() -> Self {
54 Self {
55 registered_attachments: HashMap::default(),
56 }
57 }
58
59 pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
60 let attachment = Arc::new(attachment);
61
62 let call = Box::new({
63 let attachment = attachment.clone();
64 move |cx: &mut WindowContext| {
65 let result = attachment.run(cx);
66 let attachment = attachment.clone();
67 cx.spawn(move |mut cx| async move {
68 let result: Result<A::Output> = result.await;
69 let serialized_output =
70 result
71 .as_ref()
72 .map_err(ToString::to_string)
73 .and_then(|output| {
74 Ok(RawValue::from_string(
75 serde_json::to_string(output).map_err(|e| e.to_string())?,
76 )
77 .unwrap())
78 });
79
80 let view = cx.update(|cx| attachment.view(result, cx))?;
81
82 Ok(UserAttachment {
83 name: attachment.name(),
84 view: view.into(),
85 generate_fn: generate::<A>,
86 serialized_output,
87 })
88 })
89 }
90 });
91
92 let deserialize = Box::new({
93 let attachment = attachment.clone();
94 move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
95 let serialized_output = saved_attachment.serialized_output.clone();
96 let output = match &serialized_output {
97 Ok(serialized_output) => {
98 Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
99 }
100 Err(error) => Err(anyhow!("{error}")),
101 };
102 let view = attachment.view(output, cx).into();
103
104 Ok(UserAttachment {
105 name: saved_attachment.name.clone(),
106 view,
107 serialized_output,
108 generate_fn: generate::<A>,
109 })
110 }
111 });
112
113 self.registered_attachments.insert(
114 TypeId::of::<A>(),
115 RegisteredAttachment {
116 name: attachment.name(),
117 call,
118 deserialize,
119 enabled: AtomicBool::new(true),
120 },
121 );
122 return;
123
124 fn generate<T: LanguageModelAttachment>(
125 view: AnyView,
126 project: &mut ProjectContext,
127 cx: &mut WindowContext,
128 ) -> String {
129 view.downcast::<T::View>()
130 .unwrap()
131 .update(cx, |view, cx| T::View::generate(view, project, cx))
132 }
133 }
134
135 pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
136 &self,
137 is_enabled: bool,
138 ) {
139 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
140 attachment.enabled.store(is_enabled, SeqCst);
141 }
142 }
143
144 pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
145 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
146 attachment.enabled.load(SeqCst)
147 } else {
148 false
149 }
150 }
151
152 pub fn call<A: LanguageModelAttachment + 'static>(
153 &self,
154 cx: &mut WindowContext,
155 ) -> Task<Result<UserAttachment>> {
156 let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
157 return Task::ready(Err(anyhow!("no attachment tool")));
158 };
159
160 (attachment.call)(cx)
161 }
162
163 pub fn call_all_attachment_tools(
164 self: Arc<Self>,
165 cx: &mut WindowContext<'_>,
166 ) -> Task<Result<Vec<UserAttachment>>> {
167 let this = self.clone();
168 cx.spawn(|mut cx| async move {
169 let attachment_tasks = cx.update(|cx| {
170 let mut tasks = Vec::new();
171 for attachment in this
172 .registered_attachments
173 .values()
174 .filter(|attachment| attachment.enabled.load(SeqCst))
175 {
176 tasks.push((attachment.call)(cx))
177 }
178
179 tasks
180 })?;
181
182 let attachments = join_all(attachment_tasks.into_iter()).await;
183
184 Ok(attachments
185 .into_iter()
186 .filter_map(|attachment| attachment.log_err())
187 .collect())
188 })
189 }
190
191 pub fn serialize_user_attachment(
192 &self,
193 user_attachment: &UserAttachment,
194 ) -> SavedUserAttachment {
195 SavedUserAttachment {
196 name: user_attachment.name.clone(),
197 serialized_output: user_attachment.serialized_output.clone(),
198 }
199 }
200
201 pub fn deserialize_user_attachment(
202 &self,
203 saved_user_attachment: SavedUserAttachment,
204 cx: &mut WindowContext,
205 ) -> Result<UserAttachment> {
206 if let Some(registered_attachment) = self
207 .registered_attachments
208 .values()
209 .find(|attachment| attachment.name == saved_user_attachment.name)
210 {
211 (registered_attachment.deserialize)(&saved_user_attachment, cx)
212 } else {
213 Err(anyhow!(
214 "no attachment tool for name {}",
215 saved_user_attachment.name
216 ))
217 }
218 }
219}
220
221impl UserAttachment {
222 pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
223 let result = (self.generate_fn)(self.view.clone(), output, cx);
224 if result.is_empty() {
225 None
226 } else {
227 Some(result)
228 }
229 }
230}