1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::{self, FutureExt};
8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
9use language::Buffer;
10use language_model::LanguageModelImage;
11use project::{Project, ProjectItem, ProjectPath, Symbol};
12use prompt_store::UserPromptId;
13use ref_cast::RefCast as _;
14use text::{Anchor, OffsetRangeExt};
15
16use crate::ThreadStore;
17use crate::context::{
18 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
19 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
20 SymbolContextHandle, ThreadContextHandle,
21};
22use crate::context_strip::SuggestedContext;
23use crate::thread::{Thread, ThreadId};
24
25pub struct ContextStore {
26 project: WeakEntity<Project>,
27 thread_store: Option<WeakEntity<ThreadStore>>,
28 next_context_id: ContextId,
29 context_set: IndexSet<AgentContextKey>,
30 context_thread_ids: HashSet<ThreadId>,
31}
32
33impl ContextStore {
34 pub fn new(
35 project: WeakEntity<Project>,
36 thread_store: Option<WeakEntity<ThreadStore>>,
37 ) -> Self {
38 Self {
39 project,
40 thread_store,
41 next_context_id: ContextId::zero(),
42 context_set: IndexSet::default(),
43 context_thread_ids: HashSet::default(),
44 }
45 }
46
47 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
48 self.context_set.iter().map(|entry| entry.as_ref())
49 }
50
51 pub fn clear(&mut self) {
52 self.context_set.clear();
53 self.context_thread_ids.clear();
54 }
55
56 pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
57 let existing_context = thread
58 .messages()
59 .flat_map(|message| {
60 message
61 .loaded_context
62 .contexts
63 .iter()
64 .map(|context| AgentContextKey(context.handle()))
65 })
66 .collect::<HashSet<_>>();
67 self.context_set
68 .iter()
69 .filter(|context| !existing_context.contains(context))
70 .map(|entry| entry.0.clone())
71 .collect::<Vec<_>>()
72 }
73
74 pub fn add_file_from_path(
75 &mut self,
76 project_path: ProjectPath,
77 remove_if_exists: bool,
78 cx: &mut Context<Self>,
79 ) -> Task<Result<()>> {
80 let Some(project) = self.project.upgrade() else {
81 return Task::ready(Err(anyhow!("failed to read project")));
82 };
83
84 cx.spawn(async move |this, cx| {
85 let open_buffer_task = project.update(cx, |project, cx| {
86 project.open_buffer(project_path.clone(), cx)
87 })?;
88 let buffer = open_buffer_task.await?;
89 this.update(cx, |this, cx| {
90 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
91 })
92 })
93 }
94
95 pub fn add_file_from_buffer(
96 &mut self,
97 project_path: &ProjectPath,
98 buffer: Entity<Buffer>,
99 remove_if_exists: bool,
100 cx: &mut Context<Self>,
101 ) {
102 let context_id = self.next_context_id.post_inc();
103 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
104
105 let already_included = if self.has_context(&context) {
106 if remove_if_exists {
107 self.remove_context(&context, cx);
108 }
109 true
110 } else {
111 self.path_included_in_directory(project_path, cx).is_some()
112 };
113
114 if !already_included {
115 self.insert_context(context, cx);
116 }
117 }
118
119 pub fn add_directory(
120 &mut self,
121 project_path: &ProjectPath,
122 remove_if_exists: bool,
123 cx: &mut Context<Self>,
124 ) -> Result<()> {
125 let Some(project) = self.project.upgrade() else {
126 return Err(anyhow!("failed to read project"));
127 };
128
129 let Some(entry_id) = project
130 .read(cx)
131 .entry_for_path(project_path, cx)
132 .map(|entry| entry.id)
133 else {
134 return Err(anyhow!("no entry found for directory context"));
135 };
136
137 let context_id = self.next_context_id.post_inc();
138 let context = AgentContextHandle::Directory(DirectoryContextHandle {
139 entry_id,
140 context_id,
141 });
142
143 if self.has_context(&context) {
144 if remove_if_exists {
145 self.remove_context(&context, cx);
146 }
147 } else if self.path_included_in_directory(project_path, cx).is_none() {
148 self.insert_context(context, cx);
149 }
150
151 anyhow::Ok(())
152 }
153
154 pub fn add_symbol(
155 &mut self,
156 buffer: Entity<Buffer>,
157 symbol: SharedString,
158 range: Range<Anchor>,
159 enclosing_range: Range<Anchor>,
160 remove_if_exists: bool,
161 cx: &mut Context<Self>,
162 ) -> bool {
163 let context_id = self.next_context_id.post_inc();
164 let context = AgentContextHandle::Symbol(SymbolContextHandle {
165 buffer,
166 symbol,
167 range,
168 enclosing_range,
169 context_id,
170 });
171
172 if self.has_context(&context) {
173 if remove_if_exists {
174 self.remove_context(&context, cx);
175 }
176 return false;
177 }
178
179 self.insert_context(context, cx)
180 }
181
182 pub fn add_thread(
183 &mut self,
184 thread: Entity<Thread>,
185 remove_if_exists: bool,
186 cx: &mut Context<Self>,
187 ) {
188 let context_id = self.next_context_id.post_inc();
189 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
190
191 if self.has_context(&context) {
192 if remove_if_exists {
193 self.remove_context(&context, cx);
194 }
195 } else {
196 self.insert_context(context, cx);
197 }
198 }
199
200 pub fn add_rules(
201 &mut self,
202 prompt_id: UserPromptId,
203 remove_if_exists: bool,
204 cx: &mut Context<ContextStore>,
205 ) {
206 let context_id = self.next_context_id.post_inc();
207 let context = AgentContextHandle::Rules(RulesContextHandle {
208 prompt_id,
209 context_id,
210 });
211
212 if self.has_context(&context) {
213 if remove_if_exists {
214 self.remove_context(&context, cx);
215 }
216 } else {
217 self.insert_context(context, cx);
218 }
219 }
220
221 pub fn add_fetched_url(
222 &mut self,
223 url: String,
224 text: impl Into<SharedString>,
225 cx: &mut Context<ContextStore>,
226 ) {
227 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
228 url: url.into(),
229 text: text.into(),
230 context_id: self.next_context_id.post_inc(),
231 });
232
233 self.insert_context(context, cx);
234 }
235
236 pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
237 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
238 let context = AgentContextHandle::Image(ImageContext {
239 original_image: image,
240 image_task,
241 context_id: self.next_context_id.post_inc(),
242 });
243 self.insert_context(context, cx);
244 }
245
246 pub fn add_selection(
247 &mut self,
248 buffer: Entity<Buffer>,
249 range: Range<Anchor>,
250 cx: &mut Context<ContextStore>,
251 ) {
252 let context_id = self.next_context_id.post_inc();
253 let context = AgentContextHandle::Selection(SelectionContextHandle {
254 buffer,
255 range,
256 context_id,
257 });
258 self.insert_context(context, cx);
259 }
260
261 pub fn add_suggested_context(
262 &mut self,
263 suggested: &SuggestedContext,
264 cx: &mut Context<ContextStore>,
265 ) {
266 match suggested {
267 SuggestedContext::File {
268 buffer,
269 icon_path: _,
270 name: _,
271 } => {
272 if let Some(buffer) = buffer.upgrade() {
273 let context_id = self.next_context_id.post_inc();
274 self.insert_context(
275 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
276 cx,
277 );
278 };
279 }
280 SuggestedContext::Thread { thread, name: _ } => {
281 if let Some(thread) = thread.upgrade() {
282 let context_id = self.next_context_id.post_inc();
283 self.insert_context(
284 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
285 cx,
286 );
287 }
288 }
289 }
290 }
291
292 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
293 match &context {
294 AgentContextHandle::Thread(thread_context) => {
295 if let Some(thread_store) = self.thread_store.clone() {
296 thread_context.thread.update(cx, |thread, cx| {
297 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
298 });
299 self.context_thread_ids
300 .insert(thread_context.thread.read(cx).id().clone());
301 } else {
302 return false;
303 }
304 }
305 _ => {}
306 }
307 let inserted = self.context_set.insert(AgentContextKey(context));
308 if inserted {
309 cx.notify();
310 }
311 inserted
312 }
313
314 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
315 if self
316 .context_set
317 .shift_remove(AgentContextKey::ref_cast(context))
318 {
319 match context {
320 AgentContextHandle::Thread(thread_context) => {
321 self.context_thread_ids
322 .remove(thread_context.thread.read(cx).id());
323 }
324 _ => {}
325 }
326 cx.notify();
327 }
328 }
329
330 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
331 self.context_set
332 .contains(AgentContextKey::ref_cast(context))
333 }
334
335 /// Returns whether this file path is already included directly in the context, or if it will be
336 /// included in the context via a directory.
337 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
338 let project = self.project.upgrade()?.read(cx);
339 self.context().find_map(|context| match context {
340 AgentContextHandle::File(file_context) => {
341 FileInclusion::check_file(file_context, path, cx)
342 }
343 AgentContextHandle::Directory(directory_context) => {
344 FileInclusion::check_directory(directory_context, path, project, cx)
345 }
346 _ => None,
347 })
348 }
349
350 pub fn path_included_in_directory(
351 &self,
352 path: &ProjectPath,
353 cx: &App,
354 ) -> Option<FileInclusion> {
355 let project = self.project.upgrade()?.read(cx);
356 self.context().find_map(|context| match context {
357 AgentContextHandle::Directory(directory_context) => {
358 FileInclusion::check_directory(directory_context, path, project, cx)
359 }
360 _ => None,
361 })
362 }
363
364 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
365 self.context().any(|context| match context {
366 AgentContextHandle::Symbol(context) => {
367 if context.symbol != symbol.name {
368 return false;
369 }
370 let buffer = context.buffer.read(cx);
371 let Some(context_path) = buffer.project_path(cx) else {
372 return false;
373 };
374 if context_path != symbol.path {
375 return false;
376 }
377 let context_range = context.range.to_point_utf16(&buffer.snapshot());
378 context_range.start == symbol.range.start.0
379 && context_range.end == symbol.range.end.0
380 }
381 _ => false,
382 })
383 }
384
385 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
386 self.context_thread_ids.contains(thread_id)
387 }
388
389 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
390 self.context_set
391 .contains(&RulesContextHandle::lookup_key(prompt_id))
392 }
393
394 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
395 self.context_set
396 .contains(&FetchedUrlContext::lookup_key(url.into()))
397 }
398
399 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
400 self.context()
401 .filter_map(|context| match context {
402 AgentContextHandle::File(file) => {
403 let buffer = file.buffer.read(cx);
404 buffer.project_path(cx)
405 }
406 AgentContextHandle::Directory(_)
407 | AgentContextHandle::Symbol(_)
408 | AgentContextHandle::Selection(_)
409 | AgentContextHandle::FetchedUrl(_)
410 | AgentContextHandle::Thread(_)
411 | AgentContextHandle::Rules(_)
412 | AgentContextHandle::Image(_) => None,
413 })
414 .collect()
415 }
416
417 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
418 &self.context_thread_ids
419 }
420}
421
/// How a project path is covered by the context in a `ContextStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself was added, either as a file or as a directory.
    Direct,
    /// The path lies inside an added directory. `full_path` is that directory's path as
    /// reported by `Worktree::full_path`.
    InDirectory { full_path: PathBuf },
}
426
427impl FileInclusion {
428 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
429 let file_path = file_context.buffer.read(cx).project_path(cx)?;
430 if path == &file_path {
431 Some(FileInclusion::Direct)
432 } else {
433 None
434 }
435 }
436
437 fn check_directory(
438 directory_context: &DirectoryContextHandle,
439 path: &ProjectPath,
440 project: &Project,
441 cx: &App,
442 ) -> Option<Self> {
443 let worktree = project
444 .worktree_for_entry(directory_context.entry_id, cx)?
445 .read(cx);
446 let entry = worktree.entry_for_id(directory_context.entry_id)?;
447 let directory_path = ProjectPath {
448 worktree_id: worktree.id(),
449 path: entry.path.clone(),
450 };
451 if path.starts_with(&directory_path) {
452 if path == &directory_path {
453 Some(FileInclusion::Direct)
454 } else {
455 Some(FileInclusion::InDirectory {
456 full_path: worktree.full_path(&entry.path),
457 })
458 }
459 } else {
460 None
461 }
462 }
463}