attachment_registry.rs

  1use crate::ProjectContext;
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use futures::future::join_all;
  5use gpui::{AnyView, Render, Task, View, WindowContext};
  6use serde::{de::DeserializeOwned, Deserialize, Serialize};
  7use serde_json::value::RawValue;
  8use std::{
  9    any::TypeId,
 10    sync::{
 11        atomic::{AtomicBool, Ordering::SeqCst},
 12        Arc,
 13    },
 14};
 15use util::ResultExt as _;
 16
 17pub struct AttachmentRegistry {
 18    registered_attachments: HashMap<TypeId, RegisteredAttachment>,
 19}
 20
 21pub trait AttachmentOutput {
 22    fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
 23}
 24
 25pub trait LanguageModelAttachment {
 26    type Output: DeserializeOwned + Serialize + 'static;
 27    type View: Render + AttachmentOutput;
 28
 29    fn name(&self) -> Arc<str>;
 30    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 31    fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
 32}
 33
 34/// A collected attachment from running an attachment tool
 35pub struct UserAttachment {
 36    pub view: AnyView,
 37    name: Arc<str>,
 38    serialized_output: Result<Box<RawValue>, String>,
 39    generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
 40}
 41
 42#[derive(Serialize, Deserialize)]
 43pub struct SavedUserAttachment {
 44    name: Arc<str>,
 45    serialized_output: Result<Box<RawValue>, String>,
 46}
 47
 48/// Internal representation of an attachment tool to allow us to treat them dynamically
 49struct RegisteredAttachment {
 50    name: Arc<str>,
 51    enabled: AtomicBool,
 52    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
 53    deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
 54}
 55
 56impl AttachmentRegistry {
 57    pub fn new() -> Self {
 58        Self {
 59            registered_attachments: HashMap::default(),
 60        }
 61    }
 62
 63    pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
 64        let attachment = Arc::new(attachment);
 65
 66        let call = Box::new({
 67            let attachment = attachment.clone();
 68            move |cx: &mut WindowContext| {
 69                let result = attachment.run(cx);
 70                let attachment = attachment.clone();
 71                cx.spawn(move |mut cx| async move {
 72                    let result: Result<A::Output> = result.await;
 73                    let serialized_output =
 74                        result
 75                            .as_ref()
 76                            .map_err(ToString::to_string)
 77                            .and_then(|output| {
 78                                Ok(RawValue::from_string(
 79                                    serde_json::to_string(output).map_err(|e| e.to_string())?,
 80                                )
 81                                .unwrap())
 82                            });
 83
 84                    let view = cx.update(|cx| attachment.view(result, cx))?;
 85
 86                    Ok(UserAttachment {
 87                        name: attachment.name(),
 88                        view: view.into(),
 89                        generate_fn: generate::<A>,
 90                        serialized_output,
 91                    })
 92                })
 93            }
 94        });
 95
 96        let deserialize = Box::new({
 97            let attachment = attachment.clone();
 98            move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
 99                let serialized_output = saved_attachment.serialized_output.clone();
100                let output = match &serialized_output {
101                    Ok(serialized_output) => {
102                        Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
103                    }
104                    Err(error) => Err(anyhow!("{error}")),
105                };
106                let view = attachment.view(output, cx).into();
107
108                Ok(UserAttachment {
109                    name: saved_attachment.name.clone(),
110                    view,
111                    serialized_output,
112                    generate_fn: generate::<A>,
113                })
114            }
115        });
116
117        self.registered_attachments.insert(
118            TypeId::of::<A>(),
119            RegisteredAttachment {
120                name: attachment.name(),
121                call,
122                deserialize,
123                enabled: AtomicBool::new(true),
124            },
125        );
126        return;
127
128        fn generate<T: LanguageModelAttachment>(
129            view: AnyView,
130            project: &mut ProjectContext,
131            cx: &mut WindowContext,
132        ) -> String {
133            view.downcast::<T::View>()
134                .unwrap()
135                .update(cx, |view, cx| T::View::generate(view, project, cx))
136        }
137    }
138
139    pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
140        &self,
141        is_enabled: bool,
142    ) {
143        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
144            attachment.enabled.store(is_enabled, SeqCst);
145        }
146    }
147
148    pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
149        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
150            attachment.enabled.load(SeqCst)
151        } else {
152            false
153        }
154    }
155
156    pub fn call<A: LanguageModelAttachment + 'static>(
157        &self,
158        cx: &mut WindowContext,
159    ) -> Task<Result<UserAttachment>> {
160        let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
161            return Task::ready(Err(anyhow!("no attachment tool")));
162        };
163
164        (attachment.call)(cx)
165    }
166
167    pub fn call_all_attachment_tools(
168        self: Arc<Self>,
169        cx: &mut WindowContext<'_>,
170    ) -> Task<Result<Vec<UserAttachment>>> {
171        let this = self.clone();
172        cx.spawn(|mut cx| async move {
173            let attachment_tasks = cx.update(|cx| {
174                let mut tasks = Vec::new();
175                for attachment in this
176                    .registered_attachments
177                    .values()
178                    .filter(|attachment| attachment.enabled.load(SeqCst))
179                {
180                    tasks.push((attachment.call)(cx))
181                }
182
183                tasks
184            })?;
185
186            let attachments = join_all(attachment_tasks.into_iter()).await;
187
188            Ok(attachments
189                .into_iter()
190                .filter_map(|attachment| attachment.log_err())
191                .collect())
192        })
193    }
194
195    pub fn serialize_user_attachment(
196        &self,
197        user_attachment: &UserAttachment,
198    ) -> SavedUserAttachment {
199        SavedUserAttachment {
200            name: user_attachment.name.clone(),
201            serialized_output: user_attachment.serialized_output.clone(),
202        }
203    }
204
205    pub fn deserialize_user_attachment(
206        &self,
207        saved_user_attachment: SavedUserAttachment,
208        cx: &mut WindowContext,
209    ) -> Result<UserAttachment> {
210        if let Some(registered_attachment) = self
211            .registered_attachments
212            .values()
213            .find(|attachment| attachment.name == saved_user_attachment.name)
214        {
215            (registered_attachment.deserialize)(&saved_user_attachment, cx)
216        } else {
217            Err(anyhow!(
218                "no attachment tool for name {}",
219                saved_user_attachment.name
220            ))
221        }
222    }
223}
224
225impl UserAttachment {
226    pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
227        let result = (self.generate_fn)(self.view.clone(), output, cx);
228        if result.is_empty() {
229            None
230        } else {
231            Some(result)
232        }
233    }
234}