attachment_registry.rs

  1use crate::{ProjectContext, ToolOutput};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use futures::future::join_all;
  5use gpui::{AnyView, Render, Task, View, WindowContext};
  6use std::{
  7    any::TypeId,
  8    sync::{
  9        atomic::{AtomicBool, Ordering::SeqCst},
 10        Arc,
 11    },
 12};
 13use util::ResultExt as _;
 14
 15pub struct AttachmentRegistry {
 16    registered_attachments: HashMap<TypeId, RegisteredAttachment>,
 17}
 18
 19pub trait LanguageModelAttachment {
 20    type Output: 'static;
 21    type View: Render + ToolOutput;
 22
 23    fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 24
 25    fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
 26}
 27
 28/// A collected attachment from running an attachment tool
 29pub struct UserAttachment {
 30    pub view: AnyView,
 31    generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
 32}
 33
 34/// Internal representation of an attachment tool to allow us to treat them dynamically
 35struct RegisteredAttachment {
 36    enabled: AtomicBool,
 37    call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
 38}
 39
 40impl AttachmentRegistry {
 41    pub fn new() -> Self {
 42        Self {
 43            registered_attachments: HashMap::default(),
 44        }
 45    }
 46
 47    pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
 48        let call = Box::new(move |cx: &mut WindowContext| {
 49            let result = attachment.run(cx);
 50
 51            cx.spawn(move |mut cx| async move {
 52                let result: Result<A::Output> = result.await;
 53                let view = cx.update(|cx| A::view(result, cx))?;
 54
 55                Ok(UserAttachment {
 56                    view: view.into(),
 57                    generate_fn: generate::<A>,
 58                })
 59            })
 60        });
 61
 62        self.registered_attachments.insert(
 63            TypeId::of::<A>(),
 64            RegisteredAttachment {
 65                call,
 66                enabled: AtomicBool::new(true),
 67            },
 68        );
 69        return;
 70
 71        fn generate<T: LanguageModelAttachment>(
 72            view: AnyView,
 73            project: &mut ProjectContext,
 74            cx: &mut WindowContext,
 75        ) -> String {
 76            view.downcast::<T::View>()
 77                .unwrap()
 78                .update(cx, |view, cx| T::View::generate(view, project, cx))
 79        }
 80    }
 81
 82    pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
 83        &self,
 84        is_enabled: bool,
 85    ) {
 86        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
 87            attachment.enabled.store(is_enabled, SeqCst);
 88        }
 89    }
 90
 91    pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
 92        if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
 93            attachment.enabled.load(SeqCst)
 94        } else {
 95            false
 96        }
 97    }
 98
 99    pub fn call<A: LanguageModelAttachment + 'static>(
100        &self,
101        cx: &mut WindowContext,
102    ) -> Task<Result<UserAttachment>> {
103        let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
104            return Task::ready(Err(anyhow!("no attachment tool")));
105        };
106
107        (attachment.call)(cx)
108    }
109
110    pub fn call_all_attachment_tools(
111        self: Arc<Self>,
112        cx: &mut WindowContext<'_>,
113    ) -> Task<Result<Vec<UserAttachment>>> {
114        let this = self.clone();
115        cx.spawn(|mut cx| async move {
116            let attachment_tasks = cx.update(|cx| {
117                let mut tasks = Vec::new();
118                for attachment in this
119                    .registered_attachments
120                    .values()
121                    .filter(|attachment| attachment.enabled.load(SeqCst))
122                {
123                    tasks.push((attachment.call)(cx))
124                }
125
126                tasks
127            })?;
128
129            let attachments = join_all(attachment_tasks.into_iter()).await;
130
131            Ok(attachments
132                .into_iter()
133                .filter_map(|attachment| attachment.log_err())
134                .collect())
135        })
136    }
137}
138
139impl UserAttachment {
140    pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
141        let result = (self.generate_fn)(self.view.clone(), output, cx);
142        if result.is_empty() {
143            None
144        } else {
145            Some(result)
146        }
147    }
148}