1use crate::{ProjectContext, ToolOutput};
2use anyhow::{anyhow, Result};
3use collections::HashMap;
4use futures::future::join_all;
5use gpui::{AnyView, Render, Task, View, WindowContext};
6use std::{
7 any::TypeId,
8 sync::{
9 atomic::{AtomicBool, Ordering::SeqCst},
10 Arc,
11 },
12};
13use util::ResultExt as _;
14
15pub struct AttachmentRegistry {
16 registered_attachments: HashMap<TypeId, RegisteredAttachment>,
17}
18
19pub trait LanguageModelAttachment {
20 type Output: 'static;
21 type View: Render + ToolOutput;
22
23 fn run(&self, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
24
25 fn view(output: Result<Self::Output>, cx: &mut WindowContext) -> View<Self::View>;
26}
27
28/// A collected attachment from running an attachment tool
29pub struct UserAttachment {
30 pub view: AnyView,
31 generate_fn: fn(AnyView, &mut ProjectContext, cx: &mut WindowContext) -> String,
32}
33
34/// Internal representation of an attachment tool to allow us to treat them dynamically
35struct RegisteredAttachment {
36 enabled: AtomicBool,
37 call: Box<dyn Fn(&mut WindowContext) -> Task<Result<UserAttachment>>>,
38}
39
40impl AttachmentRegistry {
41 pub fn new() -> Self {
42 Self {
43 registered_attachments: HashMap::default(),
44 }
45 }
46
47 pub fn register<A: LanguageModelAttachment + 'static>(&mut self, attachment: A) {
48 let call = Box::new(move |cx: &mut WindowContext| {
49 let result = attachment.run(cx);
50
51 cx.spawn(move |mut cx| async move {
52 let result: Result<A::Output> = result.await;
53 let view = cx.update(|cx| A::view(result, cx))?;
54
55 Ok(UserAttachment {
56 view: view.into(),
57 generate_fn: generate::<A>,
58 })
59 })
60 });
61
62 self.registered_attachments.insert(
63 TypeId::of::<A>(),
64 RegisteredAttachment {
65 call,
66 enabled: AtomicBool::new(true),
67 },
68 );
69 return;
70
71 fn generate<T: LanguageModelAttachment>(
72 view: AnyView,
73 project: &mut ProjectContext,
74 cx: &mut WindowContext,
75 ) -> String {
76 view.downcast::<T::View>()
77 .unwrap()
78 .update(cx, |view, cx| T::View::generate(view, project, cx))
79 }
80 }
81
82 pub fn set_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(
83 &self,
84 is_enabled: bool,
85 ) {
86 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
87 attachment.enabled.store(is_enabled, SeqCst);
88 }
89 }
90
91 pub fn is_attachment_tool_enabled<A: LanguageModelAttachment + 'static>(&self) -> bool {
92 if let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) {
93 attachment.enabled.load(SeqCst)
94 } else {
95 false
96 }
97 }
98
99 pub fn call<A: LanguageModelAttachment + 'static>(
100 &self,
101 cx: &mut WindowContext,
102 ) -> Task<Result<UserAttachment>> {
103 let Some(attachment) = self.registered_attachments.get(&TypeId::of::<A>()) else {
104 return Task::ready(Err(anyhow!("no attachment tool")));
105 };
106
107 (attachment.call)(cx)
108 }
109
110 pub fn call_all_attachment_tools(
111 self: Arc<Self>,
112 cx: &mut WindowContext<'_>,
113 ) -> Task<Result<Vec<UserAttachment>>> {
114 let this = self.clone();
115 cx.spawn(|mut cx| async move {
116 let attachment_tasks = cx.update(|cx| {
117 let mut tasks = Vec::new();
118 for attachment in this
119 .registered_attachments
120 .values()
121 .filter(|attachment| attachment.enabled.load(SeqCst))
122 {
123 tasks.push((attachment.call)(cx))
124 }
125
126 tasks
127 })?;
128
129 let attachments = join_all(attachment_tasks.into_iter()).await;
130
131 Ok(attachments
132 .into_iter()
133 .filter_map(|attachment| attachment.log_err())
134 .collect())
135 })
136 }
137}
138
139impl UserAttachment {
140 pub fn generate(&self, output: &mut ProjectContext, cx: &mut WindowContext) -> Option<String> {
141 let result = (self.generate_fn)(self.view.clone(), output, cx);
142 if result.is_empty() {
143 None
144 } else {
145 Some(result)
146 }
147 }
148}