1use crate::{
2 MessageId, ThreadId,
3 context::{
4 AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
5 FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
6 SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
7 },
8 thread::Thread,
9 thread_store::ThreadStore,
10};
11use anyhow::{Context as _, Result, anyhow};
12use assistant_context::AssistantContext;
13use collections::{HashSet, IndexSet};
14use futures::{self, FutureExt};
15use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
16use language::{Buffer, File as _};
17use language_model::LanguageModelImage;
18use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
19use prompt_store::UserPromptId;
20use ref_cast::RefCast as _;
21use std::{
22 ops::Range,
23 path::{Path, PathBuf},
24 sync::Arc,
25};
26use text::{Anchor, OffsetRangeExt};
27
28pub struct ContextStore {
29 project: WeakEntity<Project>,
30 thread_store: Option<WeakEntity<ThreadStore>>,
31 next_context_id: ContextId,
32 context_set: IndexSet<AgentContextKey>,
33 context_thread_ids: HashSet<ThreadId>,
34 context_text_thread_paths: HashSet<Arc<Path>>,
35}
36
37pub enum ContextStoreEvent {
38 ContextRemoved(AgentContextKey),
39}
40
41impl EventEmitter<ContextStoreEvent> for ContextStore {}
42
43impl ContextStore {
44 pub fn new(
45 project: WeakEntity<Project>,
46 thread_store: Option<WeakEntity<ThreadStore>>,
47 ) -> Self {
48 Self {
49 project,
50 thread_store,
51 next_context_id: ContextId::zero(),
52 context_set: IndexSet::default(),
53 context_thread_ids: HashSet::default(),
54 context_text_thread_paths: HashSet::default(),
55 }
56 }
57
58 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
59 self.context_set.iter().map(|entry| entry.as_ref())
60 }
61
62 pub fn clear(&mut self, cx: &mut Context<Self>) {
63 self.context_set.clear();
64 self.context_thread_ids.clear();
65 cx.notify();
66 }
67
68 pub fn new_context_for_thread(
69 &self,
70 thread: &Thread,
71 exclude_messages_from_id: Option<MessageId>,
72 ) -> Vec<AgentContextHandle> {
73 let existing_context = thread
74 .messages()
75 .iter()
76 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
77 .flat_map(|message| {
78 message
79 .loaded_context
80 .contexts
81 .iter()
82 .map(|context| AgentContextKey(context.handle()))
83 })
84 .collect::<HashSet<_>>();
85 self.context_set
86 .iter()
87 .filter(|context| !existing_context.contains(context))
88 .map(|entry| entry.0.clone())
89 .collect::<Vec<_>>()
90 }
91
92 pub fn add_file_from_path(
93 &mut self,
94 project_path: ProjectPath,
95 remove_if_exists: bool,
96 cx: &mut Context<Self>,
97 ) -> Task<Result<Option<AgentContextHandle>>> {
98 let Some(project) = self.project.upgrade() else {
99 return Task::ready(Err(anyhow!("failed to read project")));
100 };
101
102 if is_image_file(&project, &project_path, cx) {
103 self.add_image_from_path(project_path, remove_if_exists, cx)
104 } else {
105 cx.spawn(async move |this, cx| {
106 let open_buffer_task = project.update(cx, |project, cx| {
107 project.open_buffer(project_path.clone(), cx)
108 })?;
109 let buffer = open_buffer_task.await?;
110 this.update(cx, |this, cx| {
111 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
112 })
113 })
114 }
115 }
116
117 pub fn add_file_from_buffer(
118 &mut self,
119 project_path: &ProjectPath,
120 buffer: Entity<Buffer>,
121 remove_if_exists: bool,
122 cx: &mut Context<Self>,
123 ) -> Option<AgentContextHandle> {
124 let context_id = self.next_context_id.post_inc();
125 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
126
127 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
128 if remove_if_exists {
129 self.remove_context(&context, cx);
130 None
131 } else {
132 Some(key.as_ref().clone())
133 }
134 } else if self.path_included_in_directory(project_path, cx).is_some() {
135 None
136 } else {
137 self.insert_context(context.clone(), cx);
138 Some(context)
139 }
140 }
141
142 pub fn add_directory(
143 &mut self,
144 project_path: &ProjectPath,
145 remove_if_exists: bool,
146 cx: &mut Context<Self>,
147 ) -> Result<Option<AgentContextHandle>> {
148 let project = self.project.upgrade().context("failed to read project")?;
149 let entry_id = project
150 .read(cx)
151 .entry_for_path(project_path, cx)
152 .map(|entry| entry.id)
153 .context("no entry found for directory context")?;
154
155 let context_id = self.next_context_id.post_inc();
156 let context = AgentContextHandle::Directory(DirectoryContextHandle {
157 entry_id,
158 context_id,
159 });
160
161 let context =
162 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
163 if remove_if_exists {
164 self.remove_context(&context, cx);
165 None
166 } else {
167 Some(existing.as_ref().clone())
168 }
169 } else {
170 self.insert_context(context.clone(), cx);
171 Some(context)
172 };
173
174 anyhow::Ok(context)
175 }
176
177 pub fn add_symbol(
178 &mut self,
179 buffer: Entity<Buffer>,
180 symbol: SharedString,
181 range: Range<Anchor>,
182 enclosing_range: Range<Anchor>,
183 remove_if_exists: bool,
184 cx: &mut Context<Self>,
185 ) -> (Option<AgentContextHandle>, bool) {
186 let context_id = self.next_context_id.post_inc();
187 let context = AgentContextHandle::Symbol(SymbolContextHandle {
188 buffer,
189 symbol,
190 range,
191 enclosing_range,
192 context_id,
193 });
194
195 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
196 let handle = if remove_if_exists {
197 self.remove_context(&context, cx);
198 None
199 } else {
200 Some(key.as_ref().clone())
201 };
202 return (handle, false);
203 }
204
205 let included = self.insert_context(context.clone(), cx);
206 (Some(context), included)
207 }
208
209 pub fn add_thread(
210 &mut self,
211 thread: Entity<Thread>,
212 remove_if_exists: bool,
213 cx: &mut Context<Self>,
214 ) -> Option<AgentContextHandle> {
215 let context_id = self.next_context_id.post_inc();
216 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
217
218 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
219 if remove_if_exists {
220 self.remove_context(&context, cx);
221 None
222 } else {
223 Some(existing.as_ref().clone())
224 }
225 } else {
226 self.insert_context(context.clone(), cx);
227 Some(context)
228 }
229 }
230
231 pub fn add_text_thread(
232 &mut self,
233 context: Entity<AssistantContext>,
234 remove_if_exists: bool,
235 cx: &mut Context<Self>,
236 ) -> Option<AgentContextHandle> {
237 let context_id = self.next_context_id.post_inc();
238 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
239 context,
240 context_id,
241 });
242
243 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
244 if remove_if_exists {
245 self.remove_context(&context, cx);
246 None
247 } else {
248 Some(existing.as_ref().clone())
249 }
250 } else {
251 self.insert_context(context.clone(), cx);
252 Some(context)
253 }
254 }
255
256 pub fn add_rules(
257 &mut self,
258 prompt_id: UserPromptId,
259 remove_if_exists: bool,
260 cx: &mut Context<ContextStore>,
261 ) -> Option<AgentContextHandle> {
262 let context_id = self.next_context_id.post_inc();
263 let context = AgentContextHandle::Rules(RulesContextHandle {
264 prompt_id,
265 context_id,
266 });
267
268 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
269 if remove_if_exists {
270 self.remove_context(&context, cx);
271 None
272 } else {
273 Some(existing.as_ref().clone())
274 }
275 } else {
276 self.insert_context(context.clone(), cx);
277 Some(context)
278 }
279 }
280
281 pub fn add_fetched_url(
282 &mut self,
283 url: String,
284 text: impl Into<SharedString>,
285 cx: &mut Context<ContextStore>,
286 ) -> AgentContextHandle {
287 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
288 url: url.into(),
289 text: text.into(),
290 context_id: self.next_context_id.post_inc(),
291 });
292
293 self.insert_context(context.clone(), cx);
294 context
295 }
296
297 pub fn add_image_from_path(
298 &mut self,
299 project_path: ProjectPath,
300 remove_if_exists: bool,
301 cx: &mut Context<ContextStore>,
302 ) -> Task<Result<Option<AgentContextHandle>>> {
303 let project = self.project.clone();
304 cx.spawn(async move |this, cx| {
305 let open_image_task = project.update(cx, |project, cx| {
306 project.open_image(project_path.clone(), cx)
307 })?;
308 let image_item = open_image_task.await?;
309
310 this.update(cx, |this, cx| {
311 let item = image_item.read(cx);
312 this.insert_image(
313 Some(item.project_path(cx)),
314 Some(item.file.full_path(cx).into()),
315 item.image.clone(),
316 remove_if_exists,
317 cx,
318 )
319 })
320 })
321 }
322
323 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
324 self.insert_image(None, None, image, false, cx);
325 }
326
327 fn insert_image(
328 &mut self,
329 project_path: Option<ProjectPath>,
330 full_path: Option<Arc<Path>>,
331 image: Arc<Image>,
332 remove_if_exists: bool,
333 cx: &mut Context<ContextStore>,
334 ) -> Option<AgentContextHandle> {
335 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
336 let context = AgentContextHandle::Image(ImageContext {
337 project_path,
338 full_path,
339 original_image: image,
340 image_task,
341 context_id: self.next_context_id.post_inc(),
342 });
343 if self.has_context(&context) {
344 if remove_if_exists {
345 self.remove_context(&context, cx);
346 return None;
347 }
348 }
349
350 self.insert_context(context.clone(), cx);
351 Some(context)
352 }
353
354 pub fn add_selection(
355 &mut self,
356 buffer: Entity<Buffer>,
357 range: Range<Anchor>,
358 cx: &mut Context<ContextStore>,
359 ) {
360 let context_id = self.next_context_id.post_inc();
361 let context = AgentContextHandle::Selection(SelectionContextHandle {
362 buffer,
363 range,
364 context_id,
365 });
366 self.insert_context(context, cx);
367 }
368
369 pub fn add_suggested_context(
370 &mut self,
371 suggested: &SuggestedContext,
372 cx: &mut Context<ContextStore>,
373 ) {
374 match suggested {
375 SuggestedContext::File {
376 buffer,
377 icon_path: _,
378 name: _,
379 } => {
380 if let Some(buffer) = buffer.upgrade() {
381 let context_id = self.next_context_id.post_inc();
382 self.insert_context(
383 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
384 cx,
385 );
386 };
387 }
388 SuggestedContext::Thread { thread, name: _ } => {
389 if let Some(thread) = thread.upgrade() {
390 let context_id = self.next_context_id.post_inc();
391 self.insert_context(
392 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
393 cx,
394 );
395 }
396 }
397 SuggestedContext::TextThread { context, name: _ } => {
398 if let Some(context) = context.upgrade() {
399 let context_id = self.next_context_id.post_inc();
400 self.insert_context(
401 AgentContextHandle::TextThread(TextThreadContextHandle {
402 context,
403 context_id,
404 }),
405 cx,
406 );
407 }
408 }
409 }
410 }
411
412 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
413 match &context {
414 AgentContextHandle::Thread(thread_context) => {
415 if let Some(thread_store) = self.thread_store.clone() {
416 thread_context.thread.update(cx, |thread, cx| {
417 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
418 });
419 self.context_thread_ids
420 .insert(thread_context.thread.read(cx).id().clone());
421 } else {
422 return false;
423 }
424 }
425 AgentContextHandle::TextThread(text_thread_context) => {
426 self.context_text_thread_paths
427 .extend(text_thread_context.context.read(cx).path().cloned());
428 }
429 _ => {}
430 }
431 let inserted = self.context_set.insert(AgentContextKey(context));
432 if inserted {
433 cx.notify();
434 }
435 inserted
436 }
437
438 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
439 if let Some((_, key)) = self
440 .context_set
441 .shift_remove_full(AgentContextKey::ref_cast(context))
442 {
443 match context {
444 AgentContextHandle::Thread(thread_context) => {
445 self.context_thread_ids
446 .remove(&thread_context.thread.read(cx).id());
447 }
448 AgentContextHandle::TextThread(text_thread_context) => {
449 if let Some(path) = text_thread_context.context.read(cx).path() {
450 self.context_text_thread_paths.remove(path);
451 }
452 }
453 _ => {}
454 }
455 cx.emit(ContextStoreEvent::ContextRemoved(key));
456 cx.notify();
457 }
458 }
459
460 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
461 self.context_set
462 .contains(AgentContextKey::ref_cast(context))
463 }
464
465 /// Returns whether this file path is already included directly in the context, or if it will be
466 /// included in the context via a directory.
467 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
468 let project = self.project.upgrade()?.read(cx);
469 self.context().find_map(|context| match context {
470 AgentContextHandle::File(file_context) => {
471 FileInclusion::check_file(file_context, path, cx)
472 }
473 AgentContextHandle::Image(image_context) => {
474 FileInclusion::check_image(image_context, path)
475 }
476 AgentContextHandle::Directory(directory_context) => {
477 FileInclusion::check_directory(directory_context, path, project, cx)
478 }
479 _ => None,
480 })
481 }
482
483 pub fn path_included_in_directory(
484 &self,
485 path: &ProjectPath,
486 cx: &App,
487 ) -> Option<FileInclusion> {
488 let project = self.project.upgrade()?.read(cx);
489 self.context().find_map(|context| match context {
490 AgentContextHandle::Directory(directory_context) => {
491 FileInclusion::check_directory(directory_context, path, project, cx)
492 }
493 _ => None,
494 })
495 }
496
497 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
498 self.context().any(|context| match context {
499 AgentContextHandle::Symbol(context) => {
500 if context.symbol != symbol.name {
501 return false;
502 }
503 let buffer = context.buffer.read(cx);
504 let Some(context_path) = buffer.project_path(cx) else {
505 return false;
506 };
507 if context_path != symbol.path {
508 return false;
509 }
510 let context_range = context.range.to_point_utf16(&buffer.snapshot());
511 context_range.start == symbol.range.start.0
512 && context_range.end == symbol.range.end.0
513 }
514 _ => false,
515 })
516 }
517
518 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
519 self.context_thread_ids.contains(thread_id)
520 }
521
522 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
523 self.context_text_thread_paths.contains(path)
524 }
525
526 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
527 self.context_set
528 .contains(&RulesContextHandle::lookup_key(prompt_id))
529 }
530
531 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
532 self.context_set
533 .contains(&FetchedUrlContext::lookup_key(url.into()))
534 }
535
536 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
537 self.context_set
538 .get(&FetchedUrlContext::lookup_key(url))
539 .map(|key| key.as_ref().clone())
540 }
541
542 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
543 self.context()
544 .filter_map(|context| match context {
545 AgentContextHandle::File(file) => {
546 let buffer = file.buffer.read(cx);
547 buffer.project_path(cx)
548 }
549 AgentContextHandle::Directory(_)
550 | AgentContextHandle::Symbol(_)
551 | AgentContextHandle::Selection(_)
552 | AgentContextHandle::FetchedUrl(_)
553 | AgentContextHandle::Thread(_)
554 | AgentContextHandle::TextThread(_)
555 | AgentContextHandle::Rules(_)
556 | AgentContextHandle::Image(_) => None,
557 })
558 .collect()
559 }
560
561 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
562 &self.context_thread_ids
563 }
564}
565
566#[derive(Clone)]
567pub enum SuggestedContext {
568 File {
569 name: SharedString,
570 icon_path: Option<SharedString>,
571 buffer: WeakEntity<Buffer>,
572 },
573 Thread {
574 name: SharedString,
575 thread: WeakEntity<Thread>,
576 },
577 TextThread {
578 name: SharedString,
579 context: WeakEntity<AssistantContext>,
580 },
581}
582
583impl SuggestedContext {
584 pub fn name(&self) -> &SharedString {
585 match self {
586 Self::File { name, .. } => name,
587 Self::Thread { name, .. } => name,
588 Self::TextThread { name, .. } => name,
589 }
590 }
591
592 pub fn icon_path(&self) -> Option<SharedString> {
593 match self {
594 Self::File { icon_path, .. } => icon_path.clone(),
595 Self::Thread { .. } => None,
596 Self::TextThread { .. } => None,
597 }
598 }
599
600 pub fn kind(&self) -> ContextKind {
601 match self {
602 Self::File { .. } => ContextKind::File,
603 Self::Thread { .. } => ContextKind::Thread,
604 Self::TextThread { .. } => ContextKind::TextThread,
605 }
606 }
607}
608
/// How a project path is already covered by the context in a [`ContextStore`].
pub enum FileInclusion {
    /// The path itself is in the context: as a file, as an image, or as the directory itself.
    Direct,
    /// The path lies inside a directory that is in the context; `full_path` is that directory's
    /// full path within its worktree.
    InDirectory { full_path: PathBuf },
}
613
614impl FileInclusion {
615 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
616 let file_path = file_context.buffer.read(cx).project_path(cx)?;
617 if path == &file_path {
618 Some(FileInclusion::Direct)
619 } else {
620 None
621 }
622 }
623
624 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
625 let image_path = image_context.project_path.as_ref()?;
626 if path == image_path {
627 Some(FileInclusion::Direct)
628 } else {
629 None
630 }
631 }
632
633 fn check_directory(
634 directory_context: &DirectoryContextHandle,
635 path: &ProjectPath,
636 project: &Project,
637 cx: &App,
638 ) -> Option<Self> {
639 let worktree = project
640 .worktree_for_entry(directory_context.entry_id, cx)?
641 .read(cx);
642 let entry = worktree.entry_for_id(directory_context.entry_id)?;
643 let directory_path = ProjectPath {
644 worktree_id: worktree.id(),
645 path: entry.path.clone(),
646 };
647 if path.starts_with(&directory_path) {
648 if path == &directory_path {
649 Some(FileInclusion::Direct)
650 } else {
651 Some(FileInclusion::InDirectory {
652 full_path: worktree.full_path(&entry.path),
653 })
654 }
655 } else {
656 None
657 }
658 }
659}