attachment_registry.rs

  1use crate::{ProjectContext, ToolOutput};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use futures::future::join_all;
  5use gpui::{AnyView, Render, Task, View, WindowContext};
  6use serde::{de::DeserializeOwned, Deserialize, Serialize};
  7use serde_json::value::RawValue;
  8use std::{
  9    any::TypeId,
 10    sync::{
 11        atomic::{AtomicBool, Ordering::SeqCst},
 12        Arc,
 13    },
 14};
 15use util::ResultExt as _;
 16
 17pub struct AttachmentRegistry {
 18    registered_attachments: HashMap<TypeId, RegisteredAttachment>,
 19}
 20
 21pub trait LanguageModelAttachment {
 22    type Output: DeserializeOwned + Serialize + 'static;
 23    type View: Render + ToolOutput;
 24
 25    fn name(&self) -> Arc<str>;
 26    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 27    fn view(&self, output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
 28}
 29
 30/// A collected attachment from running an attachment tool
 31pub struct UserAttachment {
 32    pub view: AnyView,
 33    name: Arc<str>,
 34    serialized_output: Result<Box<RawValue>, String>,
 35    generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
 36}
 37
 38#[derive(Serialize, Deserialize)]
 39pub struct SavedUserAttachment {
 40    name: Arc<str>,
 41    serialized_output: Result<Box<RawValue>, String>,
 42}
 43
 44/// Internal representation of an attachment tool to allow us to treat them dynamically
 45struct RegisteredAttachment {
 46    name: Arc<str>,
 47    enabled: AtomicBool,
 48    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
 49    deserialize: Box<dyn Fn(&SavedUserAttachment, &mut WindowContext) -> Result<UserAttachment>>,
 50}
 51
 52impl AttachmentRegistry {
 53    pub fn new() -> Self {
 54        Self {
 55            registered_attachments: HashMap::default(),
 56        }
 57    }
 58
 59    pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
 60        let attachment = Arc::new(attachment);
 61
 62        let call = Box::new({
 63            let attachment = attachment.clone();
 64            move |cx: &mut WindowContext| {
 65                let result = attachment.run(cx);
 66                let attachment = attachment.clone();
 67                cx.spawn(move |mut cx| async move {
 68                    let result: Result<A::Output> = result.await;
 69                    let serialized_output =
 70                        result
 71                            .as_ref()
 72                            .map_err(ToString::to_string)
 73                            .and_then(|output| {
 74                                Ok(RawValue::from_string(
 75                                    serde_json::to_string(output).map_err(|e| e.to_string())?,
 76                                )
 77                                .unwrap())
 78                            });
 79
 80                    let view = cx.update(|cx| attachment.view(result, cx))?;
 81
 82                    Ok(UserAttachment {
 83                        name: attachment.name(),
 84                        view: view.into(),
 85                        generate_fn: generate::<A>,
 86                        serialized_output,
 87                    })
 88                })
 89            }
 90        });
 91
 92        let deserialize = Box::new({
 93            let attachment = attachment.clone();
 94            move |saved_attachment: &SavedUserAttachment, cx: &mut WindowContext| {
 95                let serialized_output = saved_attachment.serialized_output.clone();
 96                let output = match &serialized_output {
 97                    Ok(serialized_output) => {
 98                        Ok(serde_json::from_str::<A::Output>(serialized_output.get())?)
 99                    }
100                    Err(error) => Err(anyhow!("{error}")),
101                };
102                let view = attachment.view(output, cx).into();
103
104                Ok(UserAttachment {
105                    name: saved_attachment.name.clone(),
106                    view,
107                    serialized_output,
108                    generate_fn: generate::<A>,
109                })
110            }
111        });
112
113        self.registered_attachments.insert(
114            TypeId::of::<A>(),
115            RegisteredAttachment {
116                name: attachment.name(),
117                call,
118                deserialize,
119                enabled: AtomicBool::new(true),
120            },
121        );
122        return;
123
124        fn generate<T: LanguageModelAttachment>(
125            view: AnyView,
126            project: &mut ProjectContext,
127            cx: &mut WindowContext,
128        ) -> String {
129            view.downcast::<T::View>()
130                .unwrap()
131                .update(cx, |view, cx| T::View::generate(view, project, cx))
132        }
133    }
134
135    pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
136        &self,
137        is_enabled: bool,
138    ) {
139        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
140            attachment.enabled.store(is_enabled, SeqCst);
141        }
142    }
143
144    pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
145        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
146            attachment.enabled.load(SeqCst)
147        } else {
148            false
149        }
150    }
151
152    pub fn call<A: LanguageModelAttachment + 'static>(
153        &self,
154        cx: &mut WindowContext,
155    ) -> Task<Result<UserAttachment>> {
156        let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
157            return Task::ready(Err(anyhow!("no attachment tool")));
158        };
159
160        (attachment.call)(cx)
161    }
162
163    pub fn call_all_attachment_tools(
164        self: Arc<Self>,
165        cx: &mut WindowContext<'_>,
166    ) -> Task<Result<Vec<UserAttachment>>> {
167        let this = self.clone();
168        cx.spawn(|mut cx| async move {
169            let attachment_tasks = cx.update(|cx| {
170                let mut tasks = Vec::new();
171                for attachment in this
172                    .registered_attachments
173                    .values()
174                    .filter(|attachment| attachment.enabled.load(SeqCst))
175                {
176                    tasks.push((attachment.call)(cx))
177                }
178
179                tasks
180            })?;
181
182            let attachments = join_all(attachment_tasks.into_iter()).await;
183
184            Ok(attachments
185                .into_iter()
186                .filter_map(|attachment| attachment.log_err())
187                .collect())
188        })
189    }
190
191    pub fn serialize_user_attachment(
192        &self,
193        user_attachment: &UserAttachment,
194    ) -> SavedUserAttachment {
195        SavedUserAttachment {
196            name: user_attachment.name.clone(),
197            serialized_output: user_attachment.serialized_output.clone(),
198        }
199    }
200
201    pub fn deserialize_user_attachment(
202        &self,
203        saved_user_attachment: SavedUserAttachment,
204        cx: &mut WindowContext,
205    ) -> Result<UserAttachment> {
206        if let Some(registered_attachment) = self
207            .registered_attachments
208            .values()
209            .find(|attachment| attachment.name == saved_user_attachment.name)
210        {
211            (registered_attachment.deserialize)(&saved_user_attachment, cx)
212        } else {
213            Err(anyhow!(
214                "no attachment tool for name {}",
215                saved_user_attachment.name
216            ))
217        }
218    }
219}
220
221impl UserAttachment {
222    pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
223        let result = (self.generate_fn)(self.view.clone(), output, cx);
224        if result.is_empty() {
225            None
226        } else {
227            Some(result)
228        }
229    }
230}