1use crate::{
2 context::{
3 AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
4 FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
5 SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
6 },
7 thread::{MessageId, Thread, ThreadId},
8 thread_store::ThreadStore,
9};
10use anyhow::{Context as _, Result, anyhow};
11use assistant_context::AssistantContext;
12use collections::{HashSet, IndexSet};
13use futures::{self, FutureExt};
14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
15use language::{Buffer, File as _};
16use language_model::LanguageModelImage;
17use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
18use prompt_store::UserPromptId;
19use ref_cast::RefCast as _;
20use std::{
21 ops::Range,
22 path::{Path, PathBuf},
23 sync::Arc,
24};
25use text::{Anchor, OffsetRangeExt};
26
27pub struct ContextStore {
28 project: WeakEntity<Project>,
29 thread_store: Option<WeakEntity<ThreadStore>>,
30 next_context_id: ContextId,
31 context_set: IndexSet<AgentContextKey>,
32 context_thread_ids: HashSet<ThreadId>,
33 context_text_thread_paths: HashSet<Arc<Path>>,
34}
35
36pub enum ContextStoreEvent {
37 ContextRemoved(AgentContextKey),
38}
39
40impl EventEmitter<ContextStoreEvent> for ContextStore {}
41
42impl ContextStore {
43 pub fn new(
44 project: WeakEntity<Project>,
45 thread_store: Option<WeakEntity<ThreadStore>>,
46 ) -> Self {
47 Self {
48 project,
49 thread_store,
50 next_context_id: ContextId::zero(),
51 context_set: IndexSet::default(),
52 context_thread_ids: HashSet::default(),
53 context_text_thread_paths: HashSet::default(),
54 }
55 }
56
57 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
58 self.context_set.iter().map(|entry| entry.as_ref())
59 }
60
61 pub fn clear(&mut self, cx: &mut Context<Self>) {
62 self.context_set.clear();
63 self.context_thread_ids.clear();
64 cx.notify();
65 }
66
67 pub fn new_context_for_thread(
68 &self,
69 thread: &Thread,
70 exclude_messages_from_id: Option<MessageId>,
71 ) -> Vec<AgentContextHandle> {
72 let existing_context = thread
73 .messages()
74 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
75 .flat_map(|message| {
76 message
77 .loaded_context
78 .contexts
79 .iter()
80 .map(|context| AgentContextKey(context.handle()))
81 })
82 .collect::<HashSet<_>>();
83 self.context_set
84 .iter()
85 .filter(|context| !existing_context.contains(context))
86 .map(|entry| entry.0.clone())
87 .collect::<Vec<_>>()
88 }
89
90 pub fn add_file_from_path(
91 &mut self,
92 project_path: ProjectPath,
93 remove_if_exists: bool,
94 cx: &mut Context<Self>,
95 ) -> Task<Result<Option<AgentContextHandle>>> {
96 let Some(project) = self.project.upgrade() else {
97 return Task::ready(Err(anyhow!("failed to read project")));
98 };
99
100 if is_image_file(&project, &project_path, cx) {
101 self.add_image_from_path(project_path, remove_if_exists, cx)
102 } else {
103 cx.spawn(async move |this, cx| {
104 let open_buffer_task = project.update(cx, |project, cx| {
105 project.open_buffer(project_path.clone(), cx)
106 })?;
107 let buffer = open_buffer_task.await?;
108 this.update(cx, |this, cx| {
109 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
110 })
111 })
112 }
113 }
114
115 pub fn add_file_from_buffer(
116 &mut self,
117 project_path: &ProjectPath,
118 buffer: Entity<Buffer>,
119 remove_if_exists: bool,
120 cx: &mut Context<Self>,
121 ) -> Option<AgentContextHandle> {
122 let context_id = self.next_context_id.post_inc();
123 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
124
125 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
126 if remove_if_exists {
127 self.remove_context(&context, cx);
128 None
129 } else {
130 Some(key.as_ref().clone())
131 }
132 } else if self.path_included_in_directory(project_path, cx).is_some() {
133 None
134 } else {
135 self.insert_context(context.clone(), cx);
136 Some(context)
137 }
138 }
139
140 pub fn add_directory(
141 &mut self,
142 project_path: &ProjectPath,
143 remove_if_exists: bool,
144 cx: &mut Context<Self>,
145 ) -> Result<Option<AgentContextHandle>> {
146 let project = self.project.upgrade().context("failed to read project")?;
147 let entry_id = project
148 .read(cx)
149 .entry_for_path(project_path, cx)
150 .map(|entry| entry.id)
151 .context("no entry found for directory context")?;
152
153 let context_id = self.next_context_id.post_inc();
154 let context = AgentContextHandle::Directory(DirectoryContextHandle {
155 entry_id,
156 context_id,
157 });
158
159 let context =
160 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
161 if remove_if_exists {
162 self.remove_context(&context, cx);
163 None
164 } else {
165 Some(existing.as_ref().clone())
166 }
167 } else {
168 self.insert_context(context.clone(), cx);
169 Some(context)
170 };
171
172 anyhow::Ok(context)
173 }
174
175 pub fn add_symbol(
176 &mut self,
177 buffer: Entity<Buffer>,
178 symbol: SharedString,
179 range: Range<Anchor>,
180 enclosing_range: Range<Anchor>,
181 remove_if_exists: bool,
182 cx: &mut Context<Self>,
183 ) -> (Option<AgentContextHandle>, bool) {
184 let context_id = self.next_context_id.post_inc();
185 let context = AgentContextHandle::Symbol(SymbolContextHandle {
186 buffer,
187 symbol,
188 range,
189 enclosing_range,
190 context_id,
191 });
192
193 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
194 let handle = if remove_if_exists {
195 self.remove_context(&context, cx);
196 None
197 } else {
198 Some(key.as_ref().clone())
199 };
200 return (handle, false);
201 }
202
203 let included = self.insert_context(context.clone(), cx);
204 (Some(context), included)
205 }
206
207 pub fn add_thread(
208 &mut self,
209 thread: Entity<Thread>,
210 remove_if_exists: bool,
211 cx: &mut Context<Self>,
212 ) -> Option<AgentContextHandle> {
213 let context_id = self.next_context_id.post_inc();
214 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
215
216 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
217 if remove_if_exists {
218 self.remove_context(&context, cx);
219 None
220 } else {
221 Some(existing.as_ref().clone())
222 }
223 } else {
224 self.insert_context(context.clone(), cx);
225 Some(context)
226 }
227 }
228
229 pub fn add_text_thread(
230 &mut self,
231 context: Entity<AssistantContext>,
232 remove_if_exists: bool,
233 cx: &mut Context<Self>,
234 ) -> Option<AgentContextHandle> {
235 let context_id = self.next_context_id.post_inc();
236 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
237 context,
238 context_id,
239 });
240
241 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
242 if remove_if_exists {
243 self.remove_context(&context, cx);
244 None
245 } else {
246 Some(existing.as_ref().clone())
247 }
248 } else {
249 self.insert_context(context.clone(), cx);
250 Some(context)
251 }
252 }
253
254 pub fn add_rules(
255 &mut self,
256 prompt_id: UserPromptId,
257 remove_if_exists: bool,
258 cx: &mut Context<ContextStore>,
259 ) -> Option<AgentContextHandle> {
260 let context_id = self.next_context_id.post_inc();
261 let context = AgentContextHandle::Rules(RulesContextHandle {
262 prompt_id,
263 context_id,
264 });
265
266 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
267 if remove_if_exists {
268 self.remove_context(&context, cx);
269 None
270 } else {
271 Some(existing.as_ref().clone())
272 }
273 } else {
274 self.insert_context(context.clone(), cx);
275 Some(context)
276 }
277 }
278
279 pub fn add_fetched_url(
280 &mut self,
281 url: String,
282 text: impl Into<SharedString>,
283 cx: &mut Context<ContextStore>,
284 ) -> AgentContextHandle {
285 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
286 url: url.into(),
287 text: text.into(),
288 context_id: self.next_context_id.post_inc(),
289 });
290
291 self.insert_context(context.clone(), cx);
292 context
293 }
294
295 pub fn add_image_from_path(
296 &mut self,
297 project_path: ProjectPath,
298 remove_if_exists: bool,
299 cx: &mut Context<ContextStore>,
300 ) -> Task<Result<Option<AgentContextHandle>>> {
301 let project = self.project.clone();
302 cx.spawn(async move |this, cx| {
303 let open_image_task = project.update(cx, |project, cx| {
304 project.open_image(project_path.clone(), cx)
305 })?;
306 let image_item = open_image_task.await?;
307
308 this.update(cx, |this, cx| {
309 let item = image_item.read(cx);
310 this.insert_image(
311 Some(item.project_path(cx)),
312 Some(item.file.full_path(cx).into()),
313 item.image.clone(),
314 remove_if_exists,
315 cx,
316 )
317 })
318 })
319 }
320
321 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
322 self.insert_image(None, None, image, false, cx);
323 }
324
325 fn insert_image(
326 &mut self,
327 project_path: Option<ProjectPath>,
328 full_path: Option<Arc<Path>>,
329 image: Arc<Image>,
330 remove_if_exists: bool,
331 cx: &mut Context<ContextStore>,
332 ) -> Option<AgentContextHandle> {
333 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
334 let context = AgentContextHandle::Image(ImageContext {
335 project_path,
336 full_path,
337 original_image: image,
338 image_task,
339 context_id: self.next_context_id.post_inc(),
340 });
341 if self.has_context(&context) && remove_if_exists {
342 self.remove_context(&context, cx);
343 return None;
344 }
345
346 self.insert_context(context.clone(), cx);
347 Some(context)
348 }
349
350 pub fn add_selection(
351 &mut self,
352 buffer: Entity<Buffer>,
353 range: Range<Anchor>,
354 cx: &mut Context<ContextStore>,
355 ) {
356 let context_id = self.next_context_id.post_inc();
357 let context = AgentContextHandle::Selection(SelectionContextHandle {
358 buffer,
359 range,
360 context_id,
361 });
362 self.insert_context(context, cx);
363 }
364
365 pub fn add_suggested_context(
366 &mut self,
367 suggested: &SuggestedContext,
368 cx: &mut Context<ContextStore>,
369 ) {
370 match suggested {
371 SuggestedContext::File {
372 buffer,
373 icon_path: _,
374 name: _,
375 } => {
376 if let Some(buffer) = buffer.upgrade() {
377 let context_id = self.next_context_id.post_inc();
378 self.insert_context(
379 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
380 cx,
381 );
382 };
383 }
384 SuggestedContext::Thread { thread, name: _ } => {
385 if let Some(thread) = thread.upgrade() {
386 let context_id = self.next_context_id.post_inc();
387 self.insert_context(
388 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
389 cx,
390 );
391 }
392 }
393 SuggestedContext::TextThread { context, name: _ } => {
394 if let Some(context) = context.upgrade() {
395 let context_id = self.next_context_id.post_inc();
396 self.insert_context(
397 AgentContextHandle::TextThread(TextThreadContextHandle {
398 context,
399 context_id,
400 }),
401 cx,
402 );
403 }
404 }
405 }
406 }
407
408 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
409 match &context {
410 AgentContextHandle::Thread(thread_context) => {
411 if let Some(thread_store) = self.thread_store.clone() {
412 thread_context.thread.update(cx, |thread, cx| {
413 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
414 });
415 self.context_thread_ids
416 .insert(thread_context.thread.read(cx).id().clone());
417 } else {
418 return false;
419 }
420 }
421 AgentContextHandle::TextThread(text_thread_context) => {
422 self.context_text_thread_paths
423 .extend(text_thread_context.context.read(cx).path().cloned());
424 }
425 _ => {}
426 }
427 let inserted = self.context_set.insert(AgentContextKey(context));
428 if inserted {
429 cx.notify();
430 }
431 inserted
432 }
433
434 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
435 if let Some((_, key)) = self
436 .context_set
437 .shift_remove_full(AgentContextKey::ref_cast(context))
438 {
439 match context {
440 AgentContextHandle::Thread(thread_context) => {
441 self.context_thread_ids
442 .remove(thread_context.thread.read(cx).id());
443 }
444 AgentContextHandle::TextThread(text_thread_context) => {
445 if let Some(path) = text_thread_context.context.read(cx).path() {
446 self.context_text_thread_paths.remove(path);
447 }
448 }
449 _ => {}
450 }
451 cx.emit(ContextStoreEvent::ContextRemoved(key));
452 cx.notify();
453 }
454 }
455
456 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
457 self.context_set
458 .contains(AgentContextKey::ref_cast(context))
459 }
460
461 /// Returns whether this file path is already included directly in the context, or if it will be
462 /// included in the context via a directory.
463 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
464 let project = self.project.upgrade()?.read(cx);
465 self.context().find_map(|context| match context {
466 AgentContextHandle::File(file_context) => {
467 FileInclusion::check_file(file_context, path, cx)
468 }
469 AgentContextHandle::Image(image_context) => {
470 FileInclusion::check_image(image_context, path)
471 }
472 AgentContextHandle::Directory(directory_context) => {
473 FileInclusion::check_directory(directory_context, path, project, cx)
474 }
475 _ => None,
476 })
477 }
478
479 pub fn path_included_in_directory(
480 &self,
481 path: &ProjectPath,
482 cx: &App,
483 ) -> Option<FileInclusion> {
484 let project = self.project.upgrade()?.read(cx);
485 self.context().find_map(|context| match context {
486 AgentContextHandle::Directory(directory_context) => {
487 FileInclusion::check_directory(directory_context, path, project, cx)
488 }
489 _ => None,
490 })
491 }
492
493 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
494 self.context().any(|context| match context {
495 AgentContextHandle::Symbol(context) => {
496 if context.symbol != symbol.name {
497 return false;
498 }
499 let buffer = context.buffer.read(cx);
500 let Some(context_path) = buffer.project_path(cx) else {
501 return false;
502 };
503 if context_path != symbol.path {
504 return false;
505 }
506 let context_range = context.range.to_point_utf16(&buffer.snapshot());
507 context_range.start == symbol.range.start.0
508 && context_range.end == symbol.range.end.0
509 }
510 _ => false,
511 })
512 }
513
514 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
515 self.context_thread_ids.contains(thread_id)
516 }
517
518 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
519 self.context_text_thread_paths.contains(path)
520 }
521
522 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
523 self.context_set
524 .contains(&RulesContextHandle::lookup_key(prompt_id))
525 }
526
527 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
528 self.context_set
529 .contains(&FetchedUrlContext::lookup_key(url.into()))
530 }
531
532 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
533 self.context_set
534 .get(&FetchedUrlContext::lookup_key(url))
535 .map(|key| key.as_ref().clone())
536 }
537
538 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
539 self.context()
540 .filter_map(|context| match context {
541 AgentContextHandle::File(file) => {
542 let buffer = file.buffer.read(cx);
543 buffer.project_path(cx)
544 }
545 AgentContextHandle::Directory(_)
546 | AgentContextHandle::Symbol(_)
547 | AgentContextHandle::Selection(_)
548 | AgentContextHandle::FetchedUrl(_)
549 | AgentContextHandle::Thread(_)
550 | AgentContextHandle::TextThread(_)
551 | AgentContextHandle::Rules(_)
552 | AgentContextHandle::Image(_) => None,
553 })
554 .collect()
555 }
556
557 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
558 &self.context_thread_ids
559 }
560}
561
562#[derive(Clone)]
563pub enum SuggestedContext {
564 File {
565 name: SharedString,
566 icon_path: Option<SharedString>,
567 buffer: WeakEntity<Buffer>,
568 },
569 Thread {
570 name: SharedString,
571 thread: WeakEntity<Thread>,
572 },
573 TextThread {
574 name: SharedString,
575 context: WeakEntity<AssistantContext>,
576 },
577}
578
579impl SuggestedContext {
580 pub fn name(&self) -> &SharedString {
581 match self {
582 Self::File { name, .. } => name,
583 Self::Thread { name, .. } => name,
584 Self::TextThread { name, .. } => name,
585 }
586 }
587
588 pub fn icon_path(&self) -> Option<SharedString> {
589 match self {
590 Self::File { icon_path, .. } => icon_path.clone(),
591 Self::Thread { .. } => None,
592 Self::TextThread { .. } => None,
593 }
594 }
595
596 pub fn kind(&self) -> ContextKind {
597 match self {
598 Self::File { .. } => ContextKind::File,
599 Self::Thread { .. } => ContextKind::Thread,
600 Self::TextThread { .. } => ContextKind::TextThread,
601 }
602 }
603}
604
/// Describes how a path is already covered by a [`ContextStore`].
pub enum FileInclusion {
    /// The path itself is an entry (file, image, or the directory itself).
    Direct,
    /// The path lies inside a directory context; `full_path` is that
    /// directory's full path within its worktree.
    InDirectory { full_path: PathBuf },
}
609
610impl FileInclusion {
611 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
612 let file_path = file_context.buffer.read(cx).project_path(cx)?;
613 if path == &file_path {
614 Some(FileInclusion::Direct)
615 } else {
616 None
617 }
618 }
619
620 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
621 let image_path = image_context.project_path.as_ref()?;
622 if path == image_path {
623 Some(FileInclusion::Direct)
624 } else {
625 None
626 }
627 }
628
629 fn check_directory(
630 directory_context: &DirectoryContextHandle,
631 path: &ProjectPath,
632 project: &Project,
633 cx: &App,
634 ) -> Option<Self> {
635 let worktree = project
636 .worktree_for_entry(directory_context.entry_id, cx)?
637 .read(cx);
638 let entry = worktree.entry_for_id(directory_context.entry_id)?;
639 let directory_path = ProjectPath {
640 worktree_id: worktree.id(),
641 path: entry.path.clone(),
642 };
643 if path.starts_with(&directory_path) {
644 if path == &directory_path {
645 Some(FileInclusion::Direct)
646 } else {
647 Some(FileInclusion::InDirectory {
648 full_path: worktree.full_path(&entry.path),
649 })
650 }
651 } else {
652 None
653 }
654 }
655}