1use std::{ops::Range, path::Path, sync::Arc};
2
3use gpui::{App, Entity, SharedString};
4use language::{Buffer, File};
5use language_model::LanguageModelRequestMessage;
6use project::{ProjectPath, Worktree};
7use prompt_store::UserPromptId;
8use rope::Point;
9use serde::{Deserialize, Serialize};
10use text::{Anchor, BufferId};
11use ui::IconName;
12use util::post_inc;
13
14use crate::thread::Thread;
15
/// Icon used for rules context; this is the icon `ContextKind::icon` returns for
/// `ContextKind::Rules`.
pub const RULES_ICON: IconName = IconName::Context;
17
18#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
19pub struct ContextId(pub(crate) usize);
20
21impl ContextId {
22 pub fn post_inc(&mut self) -> Self {
23 Self(post_inc(&mut self.0))
24 }
25}
26
/// The kinds of context that can be attached. Each kind has its own icon, see
/// `ContextKind::icon`.
///
/// Derives the common traits for a fieldless enum so that values can be debug-printed,
/// copied, compared and used as hash keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ContextKind {
    File,
    Directory,
    Symbol,
    Excerpt,
    FetchedUrl,
    Thread,
    Rules,
}
36
37impl ContextKind {
38 pub fn icon(&self) -> IconName {
39 match self {
40 ContextKind::File => IconName::File,
41 ContextKind::Directory => IconName::Folder,
42 ContextKind::Symbol => IconName::Code,
43 ContextKind::Excerpt => IconName::Code,
44 ContextKind::FetchedUrl => IconName::Globe,
45 ContextKind::Thread => IconName::MessageBubbles,
46 ContextKind::Rules => RULES_ICON,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
52pub enum AssistantContext {
53 File(FileContext),
54 Directory(DirectoryContext),
55 Symbol(SymbolContext),
56 FetchedUrl(FetchedUrlContext),
57 Thread(ThreadContext),
58 Excerpt(ExcerptContext),
59 Rules(RulesContext),
60}
61
62impl AssistantContext {
63 pub fn id(&self) -> ContextId {
64 match self {
65 Self::File(file) => file.id,
66 Self::Directory(directory) => directory.id,
67 Self::Symbol(symbol) => symbol.id,
68 Self::FetchedUrl(url) => url.id,
69 Self::Thread(thread) => thread.id,
70 Self::Excerpt(excerpt) => excerpt.id,
71 Self::Rules(rules) => rules.id,
72 }
73 }
74}
75
/// A file attached as context.
#[derive(Debug, Clone)]
pub struct FileContext {
    /// Identifier of this context entry.
    pub id: ContextId,
    /// The file's buffer data; its `text` is what `format_context_as_string` writes into the
    /// `<files>` section.
    pub context_buffer: ContextBuffer,
}
81
82#[derive(Debug, Clone)]
83pub struct DirectoryContext {
84 pub id: ContextId,
85 pub worktree: Entity<Worktree>,
86 pub path: Arc<Path>,
87 /// Buffers of the files within the directory.
88 pub context_buffers: Vec<ContextBuffer>,
89}
90
91impl DirectoryContext {
92 pub fn project_path(&self, cx: &App) -> ProjectPath {
93 ProjectPath {
94 worktree_id: self.worktree.read(cx).id(),
95 path: self.path.clone(),
96 }
97 }
98}
99
/// A symbol attached as context.
#[derive(Debug, Clone)]
pub struct SymbolContext {
    /// Identifier of this context entry.
    pub id: ContextId,
    /// The symbol's data; its `text` is what `format_context_as_string` writes into the
    /// `<symbols>` section.
    pub context_symbol: ContextSymbol,
}
105
/// A fetched URL attached as context.
#[derive(Debug, Clone)]
pub struct FetchedUrlContext {
    /// Identifier of this context entry.
    pub id: ContextId,
    /// The URL; written on its own line ahead of `text` in the `<fetched_urls>` section.
    pub url: SharedString,
    /// Text associated with the URL (presumably the fetched content — TODO confirm).
    pub text: SharedString,
}
112
113#[derive(Debug, Clone)]
114pub struct ThreadContext {
115 pub id: ContextId,
116 // TODO: Entity<Thread> holds onto the thread even if the thread is deleted. Should probably be
117 // a WeakEntity and handle removal from the UI when it has dropped.
118 pub thread: Entity<Thread>,
119 pub text: SharedString,
120}
121
122impl ThreadContext {
123 pub fn summary(&self, cx: &App) -> SharedString {
124 self.thread
125 .read(cx)
126 .summary()
127 .unwrap_or("New thread".into())
128 }
129}
130
131#[derive(Clone)]
132pub struct ContextBuffer {
133 pub id: BufferId,
134 // TODO: Entity<Buffer> holds onto the thread even if the thread is deleted. Should probably be
135 // a WeakEntity and handle removal from the UI when it has dropped.
136 pub buffer: Entity<Buffer>,
137 pub file: Arc<dyn File>,
138 pub version: clock::Global,
139 pub text: SharedString,
140}
141
142impl std::fmt::Debug for ContextBuffer {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 f.debug_struct("ContextBuffer")
145 .field("id", &self.id)
146 .field("buffer", &self.buffer)
147 .field("version", &self.version)
148 .field("text", &self.text)
149 .finish()
150 }
151}
152
/// Data for a symbol attached as context.
#[derive(Debug, Clone)]
pub struct ContextSymbol {
    /// Identity of the symbol: its path, name and range.
    pub id: ContextSymbolId,
    /// Buffer associated with the symbol (presumably the one containing it — TODO confirm).
    pub buffer: Entity<Buffer>,
    /// NOTE(review): presumably the version of `buffer` that `text` was captured at — confirm.
    pub buffer_version: clock::Global,
    /// The range that the symbol encloses, e.g. for function symbol, this will
    /// include not only the signature, but also the body
    pub enclosing_range: Range<Anchor>,
    /// Text that `format_context_as_string` writes into the `<symbols>` section.
    pub text: SharedString,
}
163
/// Identifies a symbol by project path, name and range. Comparable and hashable, so it can be
/// used as a lookup key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextSymbolId {
    /// Project path associated with the symbol (presumably the file defining it — TODO
    /// confirm).
    pub path: ProjectPath,
    /// Name of the symbol.
    pub name: SharedString,
    /// Range of the symbol. How it relates to `ContextSymbol::enclosing_range` is not visible
    /// here — verify against where these ids are created.
    pub range: Range<Anchor>,
}
170
/// An excerpt of a buffer attached as context.
#[derive(Debug, Clone)]
pub struct ExcerptContext {
    /// Identifier of this context entry.
    pub id: ContextId,
    /// The excerpt's range, expressed as anchors.
    pub range: Range<Anchor>,
    /// NOTE(review): presumably the same span as `range`, expressed as `Point`s — confirm.
    pub line_range: Range<Point>,
    /// Buffer data for the excerpt; its `text` is what `format_context_as_string` writes into
    /// the `<excerpts>` section.
    pub context_buffer: ContextBuffer,
}
178
/// User rules attached as context.
#[derive(Debug, Clone)]
pub struct RulesContext {
    /// Identifier of this context entry.
    pub id: ContextId,
    /// Id of the user prompt these rules are tied to.
    pub prompt_id: UserPromptId,
    /// Title of the rules. `format_context_as_string` does not write it out.
    pub title: SharedString,
    /// Rules text written into the `<user_rules>` section.
    pub text: SharedString,
}
186
187/// Formats a collection of contexts into a string representation
188pub fn format_context_as_string<'a>(
189 contexts: impl Iterator<Item = &'a AssistantContext>,
190 cx: &App,
191) -> Option<String> {
192 let mut file_context = Vec::new();
193 let mut directory_context = Vec::new();
194 let mut symbol_context = Vec::new();
195 let mut excerpt_context = Vec::new();
196 let mut fetch_context = Vec::new();
197 let mut thread_context = Vec::new();
198 let mut rules_context = Vec::new();
199
200 for context in contexts {
201 match context {
202 AssistantContext::File(context) => file_context.push(context),
203 AssistantContext::Directory(context) => directory_context.push(context),
204 AssistantContext::Symbol(context) => symbol_context.push(context),
205 AssistantContext::Excerpt(context) => excerpt_context.push(context),
206 AssistantContext::FetchedUrl(context) => fetch_context.push(context),
207 AssistantContext::Thread(context) => thread_context.push(context),
208 AssistantContext::Rules(context) => rules_context.push(context),
209 }
210 }
211
212 if file_context.is_empty()
213 && directory_context.is_empty()
214 && symbol_context.is_empty()
215 && excerpt_context.is_empty()
216 && fetch_context.is_empty()
217 && thread_context.is_empty()
218 && rules_context.is_empty()
219 {
220 return None;
221 }
222
223 let mut result = String::new();
224 result.push_str("\n<context>\n\
225 The following items were attached by the user. You don't need to use other tools to read them.\n\n");
226
227 if !file_context.is_empty() {
228 result.push_str("<files>\n");
229 for context in file_context {
230 result.push_str(&context.context_buffer.text);
231 }
232 result.push_str("</files>\n");
233 }
234
235 if !directory_context.is_empty() {
236 result.push_str("<directories>\n");
237 for context in directory_context {
238 for context_buffer in &context.context_buffers {
239 result.push_str(&context_buffer.text);
240 }
241 }
242 result.push_str("</directories>\n");
243 }
244
245 if !symbol_context.is_empty() {
246 result.push_str("<symbols>\n");
247 for context in symbol_context {
248 result.push_str(&context.context_symbol.text);
249 result.push('\n');
250 }
251 result.push_str("</symbols>\n");
252 }
253
254 if !excerpt_context.is_empty() {
255 result.push_str("<excerpts>\n");
256 for context in excerpt_context {
257 result.push_str(&context.context_buffer.text);
258 result.push('\n');
259 }
260 result.push_str("</excerpts>\n");
261 }
262
263 if !fetch_context.is_empty() {
264 result.push_str("<fetched_urls>\n");
265 for context in &fetch_context {
266 result.push_str(&context.url);
267 result.push('\n');
268 result.push_str(&context.text);
269 result.push('\n');
270 }
271 result.push_str("</fetched_urls>\n");
272 }
273
274 if !thread_context.is_empty() {
275 result.push_str("<conversation_threads>\n");
276 for context in &thread_context {
277 result.push_str(&context.summary(cx));
278 result.push('\n');
279 result.push_str(&context.text);
280 result.push('\n');
281 }
282 result.push_str("</conversation_threads>\n");
283 }
284
285 if !rules_context.is_empty() {
286 result.push_str(
287 "<user_rules>\n\
288 The user has specified the following rules that should be applied:\n\n",
289 );
290 for context in &rules_context {
291 result.push_str(&context.text);
292 result.push('\n');
293 }
294 result.push_str("</user_rules>\n");
295 }
296
297 result.push_str("</context>\n");
298 Some(result)
299}
300
301pub fn attach_context_to_message<'a>(
302 message: &mut LanguageModelRequestMessage,
303 contexts: impl Iterator<Item = &'a AssistantContext>,
304 cx: &App,
305) {
306 if let Some(context_string) = format_context_as_string(contexts, cx) {
307 message.content.push(context_string.into());
308 }
309}