1use crate::{
2 context::{
3 AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
4 FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
5 SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
6 },
7 thread::{MessageId, Thread, ThreadId},
8 thread_store::ThreadStore,
9};
10use anyhow::{Context as _, Result, anyhow};
11use assistant_context::AssistantContext;
12use collections::{HashSet, IndexSet};
13use futures::{self, FutureExt};
14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
15use language::{Buffer, File as _};
16use language_model::LanguageModelImage;
17use project::{
18 Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file,
19 lsp_store::SymbolLocation,
20};
21use prompt_store::UserPromptId;
22use ref_cast::RefCast as _;
23use std::{
24 ops::Range,
25 path::{Path, PathBuf},
26 sync::Arc,
27};
28use text::{Anchor, OffsetRangeExt};
29
/// Holds the context items (files, directories, symbols, threads, images, ...) attached by the
/// user, plus lookup indices for threads and text threads.
pub struct ContextStore {
    /// Project used to resolve file, directory and image context.
    project: WeakEntity<Project>,
    /// Required for thread context: `insert_context` refuses thread context when this is `None`.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Source of fresh context ids; bumped whenever an `add_*` method builds a handle.
    next_context_id: ContextId,
    /// Insertion-ordered context, deduplicated by `AgentContextKey` equality.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of threads currently included as context (backs `includes_thread`).
    context_thread_ids: HashSet<ThreadId>,
    /// Paths of text threads currently included as context (backs `includes_text_thread`).
    context_text_thread_paths: HashSet<Arc<Path>>,
}
38
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// Emitted by `remove_context`, carrying the key that was removed from the store.
    ContextRemoved(AgentContextKey),
}

impl EventEmitter<ContextStoreEvent> for ContextStore {}
44
45impl ContextStore {
46 pub fn new(
47 project: WeakEntity<Project>,
48 thread_store: Option<WeakEntity<ThreadStore>>,
49 ) -> Self {
50 Self {
51 project,
52 thread_store,
53 next_context_id: ContextId::zero(),
54 context_set: IndexSet::default(),
55 context_thread_ids: HashSet::default(),
56 context_text_thread_paths: HashSet::default(),
57 }
58 }
59
60 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
61 self.context_set.iter().map(|entry| entry.as_ref())
62 }
63
64 pub fn clear(&mut self, cx: &mut Context<Self>) {
65 self.context_set.clear();
66 self.context_thread_ids.clear();
67 cx.notify();
68 }
69
70 pub fn new_context_for_thread(
71 &self,
72 thread: &Thread,
73 exclude_messages_from_id: Option<MessageId>,
74 ) -> Vec<AgentContextHandle> {
75 let existing_context = thread
76 .messages()
77 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
78 .flat_map(|message| {
79 message
80 .loaded_context
81 .contexts
82 .iter()
83 .map(|context| AgentContextKey(context.handle()))
84 })
85 .collect::<HashSet<_>>();
86 self.context_set
87 .iter()
88 .filter(|context| !existing_context.contains(context))
89 .map(|entry| entry.0.clone())
90 .collect::<Vec<_>>()
91 }
92
93 pub fn add_file_from_path(
94 &mut self,
95 project_path: ProjectPath,
96 remove_if_exists: bool,
97 cx: &mut Context<Self>,
98 ) -> Task<Result<Option<AgentContextHandle>>> {
99 let Some(project) = self.project.upgrade() else {
100 return Task::ready(Err(anyhow!("failed to read project")));
101 };
102
103 if is_image_file(&project, &project_path, cx) {
104 self.add_image_from_path(project_path, remove_if_exists, cx)
105 } else {
106 cx.spawn(async move |this, cx| {
107 let open_buffer_task = project.update(cx, |project, cx| {
108 project.open_buffer(project_path.clone(), cx)
109 })?;
110 let buffer = open_buffer_task.await?;
111 this.update(cx, |this, cx| {
112 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
113 })
114 })
115 }
116 }
117
118 pub fn add_file_from_buffer(
119 &mut self,
120 project_path: &ProjectPath,
121 buffer: Entity<Buffer>,
122 remove_if_exists: bool,
123 cx: &mut Context<Self>,
124 ) -> Option<AgentContextHandle> {
125 let context_id = self.next_context_id.post_inc();
126 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
127
128 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
129 if remove_if_exists {
130 self.remove_context(&context, cx);
131 None
132 } else {
133 Some(key.as_ref().clone())
134 }
135 } else if self.path_included_in_directory(project_path, cx).is_some() {
136 None
137 } else {
138 self.insert_context(context.clone(), cx);
139 Some(context)
140 }
141 }
142
143 pub fn add_directory(
144 &mut self,
145 project_path: &ProjectPath,
146 remove_if_exists: bool,
147 cx: &mut Context<Self>,
148 ) -> Result<Option<AgentContextHandle>> {
149 let project = self.project.upgrade().context("failed to read project")?;
150 let entry_id = project
151 .read(cx)
152 .entry_for_path(project_path, cx)
153 .map(|entry| entry.id)
154 .context("no entry found for directory context")?;
155
156 let context_id = self.next_context_id.post_inc();
157 let context = AgentContextHandle::Directory(DirectoryContextHandle {
158 entry_id,
159 context_id,
160 });
161
162 let context =
163 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
164 if remove_if_exists {
165 self.remove_context(&context, cx);
166 None
167 } else {
168 Some(existing.as_ref().clone())
169 }
170 } else {
171 self.insert_context(context.clone(), cx);
172 Some(context)
173 };
174
175 anyhow::Ok(context)
176 }
177
178 pub fn add_symbol(
179 &mut self,
180 buffer: Entity<Buffer>,
181 symbol: SharedString,
182 range: Range<Anchor>,
183 enclosing_range: Range<Anchor>,
184 remove_if_exists: bool,
185 cx: &mut Context<Self>,
186 ) -> (Option<AgentContextHandle>, bool) {
187 let context_id = self.next_context_id.post_inc();
188 let context = AgentContextHandle::Symbol(SymbolContextHandle {
189 buffer,
190 symbol,
191 range,
192 enclosing_range,
193 context_id,
194 });
195
196 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
197 let handle = if remove_if_exists {
198 self.remove_context(&context, cx);
199 None
200 } else {
201 Some(key.as_ref().clone())
202 };
203 return (handle, false);
204 }
205
206 let included = self.insert_context(context.clone(), cx);
207 (Some(context), included)
208 }
209
210 pub fn add_thread(
211 &mut self,
212 thread: Entity<Thread>,
213 remove_if_exists: bool,
214 cx: &mut Context<Self>,
215 ) -> Option<AgentContextHandle> {
216 let context_id = self.next_context_id.post_inc();
217 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
218
219 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
220 if remove_if_exists {
221 self.remove_context(&context, cx);
222 None
223 } else {
224 Some(existing.as_ref().clone())
225 }
226 } else {
227 self.insert_context(context.clone(), cx);
228 Some(context)
229 }
230 }
231
232 pub fn add_text_thread(
233 &mut self,
234 context: Entity<AssistantContext>,
235 remove_if_exists: bool,
236 cx: &mut Context<Self>,
237 ) -> Option<AgentContextHandle> {
238 let context_id = self.next_context_id.post_inc();
239 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
240 context,
241 context_id,
242 });
243
244 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
245 if remove_if_exists {
246 self.remove_context(&context, cx);
247 None
248 } else {
249 Some(existing.as_ref().clone())
250 }
251 } else {
252 self.insert_context(context.clone(), cx);
253 Some(context)
254 }
255 }
256
257 pub fn add_rules(
258 &mut self,
259 prompt_id: UserPromptId,
260 remove_if_exists: bool,
261 cx: &mut Context<ContextStore>,
262 ) -> Option<AgentContextHandle> {
263 let context_id = self.next_context_id.post_inc();
264 let context = AgentContextHandle::Rules(RulesContextHandle {
265 prompt_id,
266 context_id,
267 });
268
269 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
270 if remove_if_exists {
271 self.remove_context(&context, cx);
272 None
273 } else {
274 Some(existing.as_ref().clone())
275 }
276 } else {
277 self.insert_context(context.clone(), cx);
278 Some(context)
279 }
280 }
281
282 pub fn add_fetched_url(
283 &mut self,
284 url: String,
285 text: impl Into<SharedString>,
286 cx: &mut Context<ContextStore>,
287 ) -> AgentContextHandle {
288 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
289 url: url.into(),
290 text: text.into(),
291 context_id: self.next_context_id.post_inc(),
292 });
293
294 self.insert_context(context.clone(), cx);
295 context
296 }
297
298 pub fn add_image_from_path(
299 &mut self,
300 project_path: ProjectPath,
301 remove_if_exists: bool,
302 cx: &mut Context<ContextStore>,
303 ) -> Task<Result<Option<AgentContextHandle>>> {
304 let project = self.project.clone();
305 cx.spawn(async move |this, cx| {
306 let open_image_task = project.update(cx, |project, cx| {
307 project.open_image(project_path.clone(), cx)
308 })?;
309 let image_item = open_image_task.await?;
310
311 this.update(cx, |this, cx| {
312 let item = image_item.read(cx);
313 this.insert_image(
314 Some(item.project_path(cx)),
315 Some(item.file.full_path(cx).to_string_lossy().into_owned()),
316 item.image.clone(),
317 remove_if_exists,
318 cx,
319 )
320 })
321 })
322 }
323
324 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
325 self.insert_image(None, None, image, false, cx);
326 }
327
328 fn insert_image(
329 &mut self,
330 project_path: Option<ProjectPath>,
331 full_path: Option<String>,
332 image: Arc<Image>,
333 remove_if_exists: bool,
334 cx: &mut Context<ContextStore>,
335 ) -> Option<AgentContextHandle> {
336 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
337 let context = AgentContextHandle::Image(ImageContext {
338 project_path,
339 full_path,
340 original_image: image,
341 image_task,
342 context_id: self.next_context_id.post_inc(),
343 });
344 if self.has_context(&context) && remove_if_exists {
345 self.remove_context(&context, cx);
346 return None;
347 }
348
349 self.insert_context(context.clone(), cx);
350 Some(context)
351 }
352
353 pub fn add_selection(
354 &mut self,
355 buffer: Entity<Buffer>,
356 range: Range<Anchor>,
357 cx: &mut Context<ContextStore>,
358 ) {
359 let context_id = self.next_context_id.post_inc();
360 let context = AgentContextHandle::Selection(SelectionContextHandle {
361 buffer,
362 range,
363 context_id,
364 });
365 self.insert_context(context, cx);
366 }
367
368 pub fn add_suggested_context(
369 &mut self,
370 suggested: &SuggestedContext,
371 cx: &mut Context<ContextStore>,
372 ) {
373 match suggested {
374 SuggestedContext::File {
375 buffer,
376 icon_path: _,
377 name: _,
378 } => {
379 if let Some(buffer) = buffer.upgrade() {
380 let context_id = self.next_context_id.post_inc();
381 self.insert_context(
382 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
383 cx,
384 );
385 };
386 }
387 SuggestedContext::Thread { thread, name: _ } => {
388 if let Some(thread) = thread.upgrade() {
389 let context_id = self.next_context_id.post_inc();
390 self.insert_context(
391 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
392 cx,
393 );
394 }
395 }
396 SuggestedContext::TextThread { context, name: _ } => {
397 if let Some(context) = context.upgrade() {
398 let context_id = self.next_context_id.post_inc();
399 self.insert_context(
400 AgentContextHandle::TextThread(TextThreadContextHandle {
401 context,
402 context_id,
403 }),
404 cx,
405 );
406 }
407 }
408 }
409 }
410
411 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
412 match &context {
413 AgentContextHandle::Thread(thread_context) => {
414 if let Some(thread_store) = self.thread_store.clone() {
415 thread_context.thread.update(cx, |thread, cx| {
416 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
417 });
418 self.context_thread_ids
419 .insert(thread_context.thread.read(cx).id().clone());
420 } else {
421 return false;
422 }
423 }
424 AgentContextHandle::TextThread(text_thread_context) => {
425 self.context_text_thread_paths
426 .extend(text_thread_context.context.read(cx).path().cloned());
427 }
428 _ => {}
429 }
430 let inserted = self.context_set.insert(AgentContextKey(context));
431 if inserted {
432 cx.notify();
433 }
434 inserted
435 }
436
437 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
438 if let Some((_, key)) = self
439 .context_set
440 .shift_remove_full(AgentContextKey::ref_cast(context))
441 {
442 match context {
443 AgentContextHandle::Thread(thread_context) => {
444 self.context_thread_ids
445 .remove(thread_context.thread.read(cx).id());
446 }
447 AgentContextHandle::TextThread(text_thread_context) => {
448 if let Some(path) = text_thread_context.context.read(cx).path() {
449 self.context_text_thread_paths.remove(path);
450 }
451 }
452 _ => {}
453 }
454 cx.emit(ContextStoreEvent::ContextRemoved(key));
455 cx.notify();
456 }
457 }
458
459 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
460 self.context_set
461 .contains(AgentContextKey::ref_cast(context))
462 }
463
464 /// Returns whether this file path is already included directly in the context, or if it will be
465 /// included in the context via a directory.
466 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
467 let project = self.project.upgrade()?.read(cx);
468 self.context().find_map(|context| match context {
469 AgentContextHandle::File(file_context) => {
470 FileInclusion::check_file(file_context, path, cx)
471 }
472 AgentContextHandle::Image(image_context) => {
473 FileInclusion::check_image(image_context, path)
474 }
475 AgentContextHandle::Directory(directory_context) => {
476 FileInclusion::check_directory(directory_context, path, project, cx)
477 }
478 _ => None,
479 })
480 }
481
482 pub fn path_included_in_directory(
483 &self,
484 path: &ProjectPath,
485 cx: &App,
486 ) -> Option<FileInclusion> {
487 let project = self.project.upgrade()?.read(cx);
488 self.context().find_map(|context| match context {
489 AgentContextHandle::Directory(directory_context) => {
490 FileInclusion::check_directory(directory_context, path, project, cx)
491 }
492 _ => None,
493 })
494 }
495
496 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
497 self.context().any(|context| match context {
498 AgentContextHandle::Symbol(context) => {
499 if context.symbol != symbol.name {
500 return false;
501 }
502 let buffer = context.buffer.read(cx);
503 let Some(context_path) = buffer.project_path(cx) else {
504 return false;
505 };
506 if symbol.path != SymbolLocation::InProject(context_path) {
507 return false;
508 }
509 let context_range = context.range.to_point_utf16(&buffer.snapshot());
510 context_range.start == symbol.range.start.0
511 && context_range.end == symbol.range.end.0
512 }
513 _ => false,
514 })
515 }
516
517 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
518 self.context_thread_ids.contains(thread_id)
519 }
520
521 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
522 self.context_text_thread_paths.contains(path)
523 }
524
525 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
526 self.context_set
527 .contains(&RulesContextHandle::lookup_key(prompt_id))
528 }
529
530 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
531 self.context_set
532 .contains(&FetchedUrlContext::lookup_key(url.into()))
533 }
534
535 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
536 self.context_set
537 .get(&FetchedUrlContext::lookup_key(url))
538 .map(|key| key.as_ref().clone())
539 }
540
541 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
542 self.context()
543 .filter_map(|context| match context {
544 AgentContextHandle::File(file) => {
545 let buffer = file.buffer.read(cx);
546 buffer.project_path(cx)
547 }
548 AgentContextHandle::Directory(_)
549 | AgentContextHandle::Symbol(_)
550 | AgentContextHandle::Selection(_)
551 | AgentContextHandle::FetchedUrl(_)
552 | AgentContextHandle::Thread(_)
553 | AgentContextHandle::TextThread(_)
554 | AgentContextHandle::Rules(_)
555 | AgentContextHandle::Image(_) => None,
556 })
557 .collect()
558 }
559
560 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
561 &self.context_thread_ids
562 }
563}
564
/// A context item offered to the user, which can be added via
/// `ContextStore::add_suggested_context`.
///
/// Entities are held weakly; if they have been dropped, adding the suggestion does nothing.
#[derive(Clone)]
pub enum SuggestedContext {
    /// A file, referenced through its buffer.
    File {
        name: SharedString,
        icon_path: Option<SharedString>,
        buffer: WeakEntity<Buffer>,
    },
    /// An agent thread.
    Thread {
        name: SharedString,
        thread: WeakEntity<Thread>,
    },
    /// A text thread (an `AssistantContext`).
    TextThread {
        name: SharedString,
        context: WeakEntity<AssistantContext>,
    },
}
581
582impl SuggestedContext {
583 pub fn name(&self) -> &SharedString {
584 match self {
585 Self::File { name, .. } => name,
586 Self::Thread { name, .. } => name,
587 Self::TextThread { name, .. } => name,
588 }
589 }
590
591 pub fn icon_path(&self) -> Option<SharedString> {
592 match self {
593 Self::File { icon_path, .. } => icon_path.clone(),
594 Self::Thread { .. } => None,
595 Self::TextThread { .. } => None,
596 }
597 }
598
599 pub fn kind(&self) -> ContextKind {
600 match self {
601 Self::File { .. } => ContextKind::File,
602 Self::Thread { .. } => ContextKind::Thread,
603 Self::TextThread { .. } => ContextKind::TextThread,
604 }
605 }
606}
607
/// How a project path is covered by the context in a [`ContextStore`].
pub enum FileInclusion {
    /// The path itself is in the store (as a file, an image, or the directory itself).
    Direct,
    /// The path lies inside a directory that is in the store.
    InDirectory {
        /// Worktree full path of the containing directory.
        full_path: PathBuf,
    },
}
612
613impl FileInclusion {
614 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
615 let file_path = file_context.buffer.read(cx).project_path(cx)?;
616 if path == &file_path {
617 Some(FileInclusion::Direct)
618 } else {
619 None
620 }
621 }
622
623 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
624 let image_path = image_context.project_path.as_ref()?;
625 if path == image_path {
626 Some(FileInclusion::Direct)
627 } else {
628 None
629 }
630 }
631
632 fn check_directory(
633 directory_context: &DirectoryContextHandle,
634 path: &ProjectPath,
635 project: &Project,
636 cx: &App,
637 ) -> Option<Self> {
638 let worktree = project
639 .worktree_for_entry(directory_context.entry_id, cx)?
640 .read(cx);
641 let entry = worktree.entry_for_id(directory_context.entry_id)?;
642 let directory_path = ProjectPath {
643 worktree_id: worktree.id(),
644 path: entry.path.clone(),
645 };
646 if path.starts_with(&directory_path) {
647 if path == &directory_path {
648 Some(FileInclusion::Direct)
649 } else {
650 Some(FileInclusion::InDirectory {
651 full_path: worktree.full_path(&entry.path),
652 })
653 }
654 } else {
655 None
656 }
657 }
658}