1use crate::ProjectContext;
2use anyhow::{anyhow, Result};
3use collections::HashMap;
4use futures::future::join_all;
5use gpui::{AnyView, Render, Task, View, WindowContext};
6use serde::{de::DeserializeOwned, Deserialize, Serialize};
7use serde_json::value::RawValue;
8use std::{
9 any::TypeId,
10 sync::{
11 atomic::{AtomicBool, Ordering::SeqCst},
12 Arc,
13 },
14};
15use util::ResultExt as _;
16
/// Registry of attachment tools.
///
/// Each tool is keyed by the `TypeId` of the concrete [`LanguageModelAttachment`]
/// type it was registered with.
pub struct AttachmentRegistry {
    // One entry per registered tool type; see `AttachmentRegistry::register`.
    registered_attachments: HashMap<TypeId, RegisteredAttachment>,
}
20
/// Implemented by attachment views so their contents can be turned into text.
pub trait AttachmentOutput {
    /// Produces the text contributed by this attachment.
    ///
    /// Receives mutable access to the `ProjectContext`. Returning an empty string is
    /// treated by `UserAttachment::generate` as "nothing to contribute" (it yields `None`).
    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
}
24
/// An attachment tool: something that can be run to collect an output and render it
/// as a view.
pub trait LanguageModelAttachment {
    /// Result of running the tool.
    ///
    /// It is serialized to JSON when collected and deserialized again when a saved
    /// attachment is restored.
    type Output: DeserializeOwned + Serialize + 'static;
    /// View that displays the output.
    ///
    /// It also implements [`AttachmentOutput`], which is how the attachment's text is
    /// generated.
    type View: Render + AttachmentOutput;

    /// Name of the tool.
    ///
    /// Saved attachments are matched back to their tool by this name, so it should
    /// presumably be unique among registered tools.
    fn name(&self) -> Arc<str>;
    /// Starts the tool, returning a task that resolves to its output.
    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
    /// Builds the view for either a successful output or the error the tool produced.
    fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
}
33
/// An attachment collected by running an attachment tool (or restored from a
/// [`SavedUserAttachment`]).
pub struct UserAttachment {
    /// View built by the tool from its output or error.
    pub view: AnyView,
    /// Name of the tool that produced this attachment; used to find the tool again
    /// when deserializing.
    name: Arc<str>,
    /// JSON-encoded output on success, or the error message on failure.
    serialized_output: Result<Box<RawValue>, String>,
    /// Type-erased function that downcasts `view` to the tool's concrete view type
    /// and generates its text.
    generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
}
41
/// Serializable form of a [`UserAttachment`].
///
/// The view is not stored; it is rebuilt from `serialized_output` on deserialization.
#[derive(Serialize, Deserialize)]
pub struct SavedUserAttachment {
    /// Tool name, matched against registered tools' names when restoring.
    name: Arc<str>,
    /// Raw JSON of the tool's output, or the error message if the tool failed.
    serialized_output: Result<Box<RawValue>, String>,
}
47
/// Internal, type-erased representation of an attachment tool, so that tools of
/// different concrete types can be stored and invoked uniformly.
struct RegisteredAttachment {
    // Value of the tool's `name()` at registration time.
    name: Arc<str>,
    // Whether `call_all_attachment_tools` includes this tool. It starts as `true`.
    // It is atomic so it can be toggled through `&self`.
    enabled: AtomicBool,
    // Runs the tool and builds a `UserAttachment` from its result.
    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
    // Rebuilds a `UserAttachment` (including its view) from its saved form.
    deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
}
55
56impl AttachmentRegistry {
57 pub fn new() -> Self {
58 Self {
59 registered_attachments: HashMap::default(),
60 }
61 }
62
63 pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
64 let attachment = Arc::new(attachment);
65
66 let call = Box::new({
67 let attachment = attachment.clone();
68 move |cx: &mut WindowContext| {
69 let result = attachment.run(cx);
70 let attachment = attachment.clone();
71 cx.spawn(move |mut cx| async move {
72 let result: Result<A::Output> = result.await;
73 let serialized_output =
74 result
75 .as_ref()
76 .map_err(ToString::to_string)
77 .and_then(|output| {
78 Ok(RawValue::from_string(
79 serde_json::to_string(output).map_err(|e| e.to_string())?,
80 )
81 .unwrap())
82 });
83
84 let view = cx.update(|cx| attachment.view(result, cx))?;
85
86 Ok(UserAttachment {
87 name: attachment.name(),
88 view: view.into(),
89 generate_fn: generate::<A>,
90 serialized_output,
91 })
92 })
93 }
94 });
95
96 let deserialize = Box::new({
97 let attachment = attachment.clone();
98 move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
99 let serialized_output = saved_attachment.serialized_output.clone();
100 let output = match &serialized_output {
101 Ok(serialized_output) => {
102 Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
103 }
104 Err(error) => Err(anyhow!("{error}")),
105 };
106 let view = attachment.view(output, cx).into();
107
108 Ok(UserAttachment {
109 name: saved_attachment.name.clone(),
110 view,
111 serialized_output,
112 generate_fn: generate::<A>,
113 })
114 }
115 });
116
117 self.registered_attachments.insert(
118 TypeId::of::<A>(),
119 RegisteredAttachment {
120 name: attachment.name(),
121 call,
122 deserialize,
123 enabled: AtomicBool::new(true),
124 },
125 );
126 return;
127
128 fn generate<T: LanguageModelAttachment>(
129 view: AnyView,
130 project: &mut ProjectContext,
131 cx: &mut WindowContext,
132 ) -> String {
133 view.downcast::<T::View>()
134 .unwrap()
135 .update(cx, |view, cx| T::View::generate(view, project, cx))
136 }
137 }
138
139 pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
140 &self,
141 is_enabled: bool,
142 ) {
143 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
144 attachment.enabled.store(is_enabled, SeqCst);
145 }
146 }
147
148 pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
149 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
150 attachment.enabled.load(SeqCst)
151 } else {
152 false
153 }
154 }
155
156 pub fn call<A: LanguageModelAttachment + 'static>(
157 &self,
158 cx: &mut WindowContext,
159 ) -> Task<Result<UserAttachment>> {
160 let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
161 return Task::ready(Err(anyhow!("no attachment tool")));
162 };
163
164 (attachment.call)(cx)
165 }
166
167 pub fn call_all_attachment_tools(
168 self: Arc<Self>,
169 cx: &mut WindowContext<'_>,
170 ) -> Task<Result<Vec<UserAttachment>>> {
171 let this = self.clone();
172 cx.spawn(|mut cx| async move {
173 let attachment_tasks = cx.update(|cx| {
174 let mut tasks = Vec::new();
175 for attachment in this
176 .registered_attachments
177 .values()
178 .filter(|attachment| attachment.enabled.load(SeqCst))
179 {
180 tasks.push((attachment.call)(cx))
181 }
182
183 tasks
184 })?;
185
186 let attachments = join_all(attachment_tasks.into_iter()).await;
187
188 Ok(attachments
189 .into_iter()
190 .filter_map(|attachment| attachment.log_err())
191 .collect())
192 })
193 }
194
195 pub fn serialize_user_attachment(
196 &self,
197 user_attachment: &UserAttachment,
198 ) -> SavedUserAttachment {
199 SavedUserAttachment {
200 name: user_attachment.name.clone(),
201 serialized_output: user_attachment.serialized_output.clone(),
202 }
203 }
204
205 pub fn deserialize_user_attachment(
206 &self,
207 saved_user_attachment: SavedUserAttachment,
208 cx: &mut WindowContext,
209 ) -> Result<UserAttachment> {
210 if let Some(registered_attachment) = self
211 .registered_attachments
212 .values()
213 .find(|attachment| attachment.name == saved_user_attachment.name)
214 {
215 (registered_attachment.deserialize)(&saved_user_attachment, cx)
216 } else {
217 Err(anyhow!(
218 "no attachment tool for name {}",
219 saved_user_attachment.name
220 ))
221 }
222 }
223}
224
225impl UserAttachment {
226 pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
227 let result = (self.generate_fn)(self.view.clone(), output, cx);
228 if result.is_empty() {
229 None
230 } else {
231 Some(result)
232 }
233 }
234}