1use std::ops::Range;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use assistant_context_editor::AssistantContext;
7use collections::{HashSet, IndexSet};
8use futures::{self, FutureExt};
9use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
10use language::Buffer;
11use language_model::LanguageModelImage;
12use project::image_store::is_image_file;
13use project::{Project, ProjectItem, ProjectPath, Symbol};
14use prompt_store::UserPromptId;
15use ref_cast::RefCast as _;
16use text::{Anchor, OffsetRangeExt};
17
18use crate::ThreadStore;
19use crate::context::{
20 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
21 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
22 SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
23};
24use crate::context_strip::SuggestedContext;
25use crate::thread::{MessageId, Thread, ThreadId};
26
27pub struct ContextStore {
28 project: WeakEntity<Project>,
29 thread_store: Option<WeakEntity<ThreadStore>>,
30 next_context_id: ContextId,
31 context_set: IndexSet<AgentContextKey>,
32 context_thread_ids: HashSet<ThreadId>,
33 context_text_thread_paths: HashSet<Arc<Path>>,
34}
35
36pub enum ContextStoreEvent {
37 ContextRemoved(AgentContextKey),
38}
39
40impl EventEmitter<ContextStoreEvent> for ContextStore {}
41
42impl ContextStore {
43 pub fn new(
44 project: WeakEntity<Project>,
45 thread_store: Option<WeakEntity<ThreadStore>>,
46 ) -> Self {
47 Self {
48 project,
49 thread_store,
50 next_context_id: ContextId::zero(),
51 context_set: IndexSet::default(),
52 context_thread_ids: HashSet::default(),
53 context_text_thread_paths: HashSet::default(),
54 }
55 }
56
57 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
58 self.context_set.iter().map(|entry| entry.as_ref())
59 }
60
61 pub fn clear(&mut self) {
62 self.context_set.clear();
63 self.context_thread_ids.clear();
64 }
65
66 pub fn new_context_for_thread(
67 &self,
68 thread: &Thread,
69 exclude_messages_from_id: Option<MessageId>,
70 ) -> Vec<AgentContextHandle> {
71 let existing_context = thread
72 .messages()
73 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
74 .flat_map(|message| {
75 message
76 .loaded_context
77 .contexts
78 .iter()
79 .map(|context| AgentContextKey(context.handle()))
80 })
81 .collect::<HashSet<_>>();
82 self.context_set
83 .iter()
84 .filter(|context| !existing_context.contains(context))
85 .map(|entry| entry.0.clone())
86 .collect::<Vec<_>>()
87 }
88
89 pub fn add_file_from_path(
90 &mut self,
91 project_path: ProjectPath,
92 remove_if_exists: bool,
93 cx: &mut Context<Self>,
94 ) -> Task<Result<Option<AgentContextHandle>>> {
95 let Some(project) = self.project.upgrade() else {
96 return Task::ready(Err(anyhow!("failed to read project")));
97 };
98
99 if is_image_file(&project, &project_path, cx) {
100 self.add_image_from_path(project_path, remove_if_exists, cx)
101 } else {
102 cx.spawn(async move |this, cx| {
103 let open_buffer_task = project.update(cx, |project, cx| {
104 project.open_buffer(project_path.clone(), cx)
105 })?;
106 let buffer = open_buffer_task.await?;
107 this.update(cx, |this, cx| {
108 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
109 })
110 })
111 }
112 }
113
114 pub fn add_file_from_buffer(
115 &mut self,
116 project_path: &ProjectPath,
117 buffer: Entity<Buffer>,
118 remove_if_exists: bool,
119 cx: &mut Context<Self>,
120 ) -> Option<AgentContextHandle> {
121 let context_id = self.next_context_id.post_inc();
122 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
123
124 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
125 if remove_if_exists {
126 self.remove_context(&context, cx);
127 None
128 } else {
129 Some(key.as_ref().clone())
130 }
131 } else if self.path_included_in_directory(project_path, cx).is_some() {
132 None
133 } else {
134 self.insert_context(context.clone(), cx);
135 Some(context)
136 }
137 }
138
139 pub fn add_directory(
140 &mut self,
141 project_path: &ProjectPath,
142 remove_if_exists: bool,
143 cx: &mut Context<Self>,
144 ) -> Result<Option<AgentContextHandle>> {
145 let Some(project) = self.project.upgrade() else {
146 return Err(anyhow!("failed to read project"));
147 };
148
149 let Some(entry_id) = project
150 .read(cx)
151 .entry_for_path(project_path, cx)
152 .map(|entry| entry.id)
153 else {
154 return Err(anyhow!("no entry found for directory context"));
155 };
156
157 let context_id = self.next_context_id.post_inc();
158 let context = AgentContextHandle::Directory(DirectoryContextHandle {
159 entry_id,
160 context_id,
161 });
162
163 let context =
164 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
165 if remove_if_exists {
166 self.remove_context(&context, cx);
167 None
168 } else {
169 Some(existing.as_ref().clone())
170 }
171 } else {
172 self.insert_context(context.clone(), cx);
173 Some(context)
174 };
175
176 anyhow::Ok(context)
177 }
178
179 pub fn add_symbol(
180 &mut self,
181 buffer: Entity<Buffer>,
182 symbol: SharedString,
183 range: Range<Anchor>,
184 enclosing_range: Range<Anchor>,
185 remove_if_exists: bool,
186 cx: &mut Context<Self>,
187 ) -> (Option<AgentContextHandle>, bool) {
188 let context_id = self.next_context_id.post_inc();
189 let context = AgentContextHandle::Symbol(SymbolContextHandle {
190 buffer,
191 symbol,
192 range,
193 enclosing_range,
194 context_id,
195 });
196
197 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
198 let handle = if remove_if_exists {
199 self.remove_context(&context, cx);
200 None
201 } else {
202 Some(key.as_ref().clone())
203 };
204 return (handle, false);
205 }
206
207 let included = self.insert_context(context.clone(), cx);
208 (Some(context), included)
209 }
210
211 pub fn add_thread(
212 &mut self,
213 thread: Entity<Thread>,
214 remove_if_exists: bool,
215 cx: &mut Context<Self>,
216 ) -> Option<AgentContextHandle> {
217 let context_id = self.next_context_id.post_inc();
218 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
219
220 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
221 if remove_if_exists {
222 self.remove_context(&context, cx);
223 None
224 } else {
225 Some(existing.as_ref().clone())
226 }
227 } else {
228 self.insert_context(context.clone(), cx);
229 Some(context)
230 }
231 }
232
233 pub fn add_text_thread(
234 &mut self,
235 context: Entity<AssistantContext>,
236 remove_if_exists: bool,
237 cx: &mut Context<Self>,
238 ) -> Option<AgentContextHandle> {
239 let context_id = self.next_context_id.post_inc();
240 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
241 context,
242 context_id,
243 });
244
245 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
246 if remove_if_exists {
247 self.remove_context(&context, cx);
248 None
249 } else {
250 Some(existing.as_ref().clone())
251 }
252 } else {
253 self.insert_context(context.clone(), cx);
254 Some(context)
255 }
256 }
257
258 pub fn add_rules(
259 &mut self,
260 prompt_id: UserPromptId,
261 remove_if_exists: bool,
262 cx: &mut Context<ContextStore>,
263 ) -> Option<AgentContextHandle> {
264 let context_id = self.next_context_id.post_inc();
265 let context = AgentContextHandle::Rules(RulesContextHandle {
266 prompt_id,
267 context_id,
268 });
269
270 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
271 if remove_if_exists {
272 self.remove_context(&context, cx);
273 None
274 } else {
275 Some(existing.as_ref().clone())
276 }
277 } else {
278 self.insert_context(context.clone(), cx);
279 Some(context)
280 }
281 }
282
283 pub fn add_fetched_url(
284 &mut self,
285 url: String,
286 text: impl Into<SharedString>,
287 cx: &mut Context<ContextStore>,
288 ) -> AgentContextHandle {
289 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
290 url: url.into(),
291 text: text.into(),
292 context_id: self.next_context_id.post_inc(),
293 });
294
295 self.insert_context(context.clone(), cx);
296 context
297 }
298
299 pub fn add_image_from_path(
300 &mut self,
301 project_path: ProjectPath,
302 remove_if_exists: bool,
303 cx: &mut Context<ContextStore>,
304 ) -> Task<Result<Option<AgentContextHandle>>> {
305 let project = self.project.clone();
306 cx.spawn(async move |this, cx| {
307 let open_image_task = project.update(cx, |project, cx| {
308 project.open_image(project_path.clone(), cx)
309 })?;
310 let image_item = open_image_task.await?;
311 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
312 this.update(cx, |this, cx| {
313 this.insert_image(
314 Some(image_item.read(cx).project_path(cx)),
315 image,
316 remove_if_exists,
317 cx,
318 )
319 })
320 })
321 }
322
323 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
324 self.insert_image(None, image, false, cx);
325 }
326
327 fn insert_image(
328 &mut self,
329 project_path: Option<ProjectPath>,
330 image: Arc<Image>,
331 remove_if_exists: bool,
332 cx: &mut Context<ContextStore>,
333 ) -> Option<AgentContextHandle> {
334 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
335 let context = AgentContextHandle::Image(ImageContext {
336 project_path,
337 original_image: image,
338 image_task,
339 context_id: self.next_context_id.post_inc(),
340 });
341 if self.has_context(&context) {
342 if remove_if_exists {
343 self.remove_context(&context, cx);
344 return None;
345 }
346 }
347
348 self.insert_context(context.clone(), cx);
349 Some(context)
350 }
351
352 pub fn add_selection(
353 &mut self,
354 buffer: Entity<Buffer>,
355 range: Range<Anchor>,
356 cx: &mut Context<ContextStore>,
357 ) {
358 let context_id = self.next_context_id.post_inc();
359 let context = AgentContextHandle::Selection(SelectionContextHandle {
360 buffer,
361 range,
362 context_id,
363 });
364 self.insert_context(context, cx);
365 }
366
367 pub fn add_suggested_context(
368 &mut self,
369 suggested: &SuggestedContext,
370 cx: &mut Context<ContextStore>,
371 ) {
372 match suggested {
373 SuggestedContext::File {
374 buffer,
375 icon_path: _,
376 name: _,
377 } => {
378 if let Some(buffer) = buffer.upgrade() {
379 let context_id = self.next_context_id.post_inc();
380 self.insert_context(
381 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
382 cx,
383 );
384 };
385 }
386 SuggestedContext::Thread { thread, name: _ } => {
387 if let Some(thread) = thread.upgrade() {
388 let context_id = self.next_context_id.post_inc();
389 self.insert_context(
390 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
391 cx,
392 );
393 }
394 }
395 SuggestedContext::TextThread { context, name: _ } => {
396 if let Some(context) = context.upgrade() {
397 let context_id = self.next_context_id.post_inc();
398 self.insert_context(
399 AgentContextHandle::TextThread(TextThreadContextHandle {
400 context,
401 context_id,
402 }),
403 cx,
404 );
405 }
406 }
407 }
408 }
409
410 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
411 match &context {
412 AgentContextHandle::Thread(thread_context) => {
413 if let Some(thread_store) = self.thread_store.clone() {
414 thread_context.thread.update(cx, |thread, cx| {
415 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
416 });
417 self.context_thread_ids
418 .insert(thread_context.thread.read(cx).id().clone());
419 } else {
420 return false;
421 }
422 }
423 AgentContextHandle::TextThread(text_thread_context) => {
424 self.context_text_thread_paths
425 .extend(text_thread_context.context.read(cx).path().cloned());
426 }
427 _ => {}
428 }
429 let inserted = self.context_set.insert(AgentContextKey(context));
430 if inserted {
431 cx.notify();
432 }
433 inserted
434 }
435
436 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
437 if let Some((_, key)) = self
438 .context_set
439 .shift_remove_full(AgentContextKey::ref_cast(context))
440 {
441 match context {
442 AgentContextHandle::Thread(thread_context) => {
443 self.context_thread_ids
444 .remove(thread_context.thread.read(cx).id());
445 }
446 AgentContextHandle::TextThread(text_thread_context) => {
447 if let Some(path) = text_thread_context.context.read(cx).path() {
448 self.context_text_thread_paths.remove(path);
449 }
450 }
451 _ => {}
452 }
453 cx.emit(ContextStoreEvent::ContextRemoved(key));
454 cx.notify();
455 }
456 }
457
458 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
459 self.context_set
460 .contains(AgentContextKey::ref_cast(context))
461 }
462
463 /// Returns whether this file path is already included directly in the context, or if it will be
464 /// included in the context via a directory.
465 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
466 let project = self.project.upgrade()?.read(cx);
467 self.context().find_map(|context| match context {
468 AgentContextHandle::File(file_context) => {
469 FileInclusion::check_file(file_context, path, cx)
470 }
471 AgentContextHandle::Image(image_context) => {
472 FileInclusion::check_image(image_context, path)
473 }
474 AgentContextHandle::Directory(directory_context) => {
475 FileInclusion::check_directory(directory_context, path, project, cx)
476 }
477 _ => None,
478 })
479 }
480
481 pub fn path_included_in_directory(
482 &self,
483 path: &ProjectPath,
484 cx: &App,
485 ) -> Option<FileInclusion> {
486 let project = self.project.upgrade()?.read(cx);
487 self.context().find_map(|context| match context {
488 AgentContextHandle::Directory(directory_context) => {
489 FileInclusion::check_directory(directory_context, path, project, cx)
490 }
491 _ => None,
492 })
493 }
494
495 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
496 self.context().any(|context| match context {
497 AgentContextHandle::Symbol(context) => {
498 if context.symbol != symbol.name {
499 return false;
500 }
501 let buffer = context.buffer.read(cx);
502 let Some(context_path) = buffer.project_path(cx) else {
503 return false;
504 };
505 if context_path != symbol.path {
506 return false;
507 }
508 let context_range = context.range.to_point_utf16(&buffer.snapshot());
509 context_range.start == symbol.range.start.0
510 && context_range.end == symbol.range.end.0
511 }
512 _ => false,
513 })
514 }
515
516 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
517 self.context_thread_ids.contains(thread_id)
518 }
519
520 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
521 self.context_text_thread_paths.contains(path)
522 }
523
524 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
525 self.context_set
526 .contains(&RulesContextHandle::lookup_key(prompt_id))
527 }
528
529 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
530 self.context_set
531 .contains(&FetchedUrlContext::lookup_key(url.into()))
532 }
533
534 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
535 self.context_set
536 .get(&FetchedUrlContext::lookup_key(url))
537 .map(|key| key.as_ref().clone())
538 }
539
540 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
541 self.context()
542 .filter_map(|context| match context {
543 AgentContextHandle::File(file) => {
544 let buffer = file.buffer.read(cx);
545 buffer.project_path(cx)
546 }
547 AgentContextHandle::Directory(_)
548 | AgentContextHandle::Symbol(_)
549 | AgentContextHandle::Selection(_)
550 | AgentContextHandle::FetchedUrl(_)
551 | AgentContextHandle::Thread(_)
552 | AgentContextHandle::TextThread(_)
553 | AgentContextHandle::Rules(_)
554 | AgentContextHandle::Image(_) => None,
555 })
556 .collect()
557 }
558
559 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
560 &self.context_thread_ids
561 }
562}
563
/// Describes how a project path is covered by a stored context.
pub enum FileInclusion {
    /// The path is exactly the path of a file, image, or directory context.
    Direct,
    /// The path lies inside a directory context; `full_path` is the worktree's full path for
    /// that directory.
    InDirectory { full_path: PathBuf },
}
568
569impl FileInclusion {
570 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
571 let file_path = file_context.buffer.read(cx).project_path(cx)?;
572 if path == &file_path {
573 Some(FileInclusion::Direct)
574 } else {
575 None
576 }
577 }
578
579 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
580 let image_path = image_context.project_path.as_ref()?;
581 if path == image_path {
582 Some(FileInclusion::Direct)
583 } else {
584 None
585 }
586 }
587
588 fn check_directory(
589 directory_context: &DirectoryContextHandle,
590 path: &ProjectPath,
591 project: &Project,
592 cx: &App,
593 ) -> Option<Self> {
594 let worktree = project
595 .worktree_for_entry(directory_context.entry_id, cx)?
596 .read(cx);
597 let entry = worktree.entry_for_id(directory_context.entry_id)?;
598 let directory_path = ProjectPath {
599 worktree_id: worktree.id(),
600 path: entry.path.clone(),
601 };
602 if path.starts_with(&directory_path) {
603 if path == &directory_path {
604 Some(FileInclusion::Direct)
605 } else {
606 Some(FileInclusion::InDirectory {
607 full_path: worktree.full_path(&entry.path),
608 })
609 }
610 } else {
611 None
612 }
613 }
614}