1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::{self, FutureExt};
8use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
9use language::Buffer;
10use language_model::LanguageModelImage;
11use project::image_store::is_image_file;
12use project::{Project, ProjectItem, ProjectPath, Symbol};
13use prompt_store::UserPromptId;
14use ref_cast::RefCast as _;
15use text::{Anchor, OffsetRangeExt};
16
17use crate::ThreadStore;
18use crate::context::{
19 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
20 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
21 SymbolContextHandle, ThreadContextHandle,
22};
23use crate::context_strip::SuggestedContext;
24use crate::thread::{MessageId, Thread, ThreadId};
25
/// Holds the context the user has attached: files, directories, symbols, selections, fetched
/// URLs, threads, rules and images.
///
/// Entries are stored as `AgentContextKey`s in a set, so adding an entry equal to one already
/// stored does not create a duplicate.
pub struct ContextStore {
    /// Project used to resolve paths, open buffers/images and look up directory entries.
    project: WeakEntity<Project>,
    /// Needed to start detailed-summary generation for thread context; when this is `None`,
    /// `insert_context` refuses to insert thread context.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Id assigned (via `post_inc`) to the next context handle created by this store.
    next_context_id: ContextId,
    /// The stored context, iterated in the order it was added.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of the threads currently stored as context, kept in sync by `insert_context`,
    /// `remove_context` and `clear` so `includes_thread` is a plain set lookup.
    context_thread_ids: HashSet<ThreadId>,
}
33
/// Events emitted by [`ContextStore`].
pub enum ContextStoreEvent {
    /// Emitted by `ContextStore::remove_context` with the key that was removed from the store.
    ContextRemoved(AgentContextKey),
}
37
// Allows `ContextStore` to emit `ContextStoreEvent`s through `cx.emit` so observers can
// subscribe to removals.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
39
40impl ContextStore {
41 pub fn new(
42 project: WeakEntity<Project>,
43 thread_store: Option<WeakEntity<ThreadStore>>,
44 ) -> Self {
45 Self {
46 project,
47 thread_store,
48 next_context_id: ContextId::zero(),
49 context_set: IndexSet::default(),
50 context_thread_ids: HashSet::default(),
51 }
52 }
53
54 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
55 self.context_set.iter().map(|entry| entry.as_ref())
56 }
57
58 pub fn clear(&mut self) {
59 self.context_set.clear();
60 self.context_thread_ids.clear();
61 }
62
63 pub fn new_context_for_thread(
64 &self,
65 thread: &Thread,
66 exclude_messages_from_id: Option<MessageId>,
67 ) -> Vec<AgentContextHandle> {
68 let existing_context = thread
69 .messages()
70 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
71 .flat_map(|message| {
72 message
73 .loaded_context
74 .contexts
75 .iter()
76 .map(|context| AgentContextKey(context.handle()))
77 })
78 .collect::<HashSet<_>>();
79 self.context_set
80 .iter()
81 .filter(|context| !existing_context.contains(context))
82 .map(|entry| entry.0.clone())
83 .collect::<Vec<_>>()
84 }
85
86 pub fn add_file_from_path(
87 &mut self,
88 project_path: ProjectPath,
89 remove_if_exists: bool,
90 cx: &mut Context<Self>,
91 ) -> Task<Result<Option<AgentContextHandle>>> {
92 let Some(project) = self.project.upgrade() else {
93 return Task::ready(Err(anyhow!("failed to read project")));
94 };
95
96 if is_image_file(&project, &project_path, cx) {
97 self.add_image_from_path(project_path, remove_if_exists, cx)
98 } else {
99 cx.spawn(async move |this, cx| {
100 let open_buffer_task = project.update(cx, |project, cx| {
101 project.open_buffer(project_path.clone(), cx)
102 })?;
103 let buffer = open_buffer_task.await?;
104 this.update(cx, |this, cx| {
105 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
106 })
107 })
108 }
109 }
110
111 pub fn add_file_from_buffer(
112 &mut self,
113 project_path: &ProjectPath,
114 buffer: Entity<Buffer>,
115 remove_if_exists: bool,
116 cx: &mut Context<Self>,
117 ) -> Option<AgentContextHandle> {
118 let context_id = self.next_context_id.post_inc();
119 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
120
121 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
122 if remove_if_exists {
123 self.remove_context(&context, cx);
124 None
125 } else {
126 Some(key.as_ref().clone())
127 }
128 } else if self.path_included_in_directory(project_path, cx).is_some() {
129 None
130 } else {
131 self.insert_context(context.clone(), cx);
132 Some(context)
133 }
134 }
135
136 pub fn add_directory(
137 &mut self,
138 project_path: &ProjectPath,
139 remove_if_exists: bool,
140 cx: &mut Context<Self>,
141 ) -> Result<Option<AgentContextHandle>> {
142 let Some(project) = self.project.upgrade() else {
143 return Err(anyhow!("failed to read project"));
144 };
145
146 let Some(entry_id) = project
147 .read(cx)
148 .entry_for_path(project_path, cx)
149 .map(|entry| entry.id)
150 else {
151 return Err(anyhow!("no entry found for directory context"));
152 };
153
154 let context_id = self.next_context_id.post_inc();
155 let context = AgentContextHandle::Directory(DirectoryContextHandle {
156 entry_id,
157 context_id,
158 });
159
160 let context =
161 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
162 if remove_if_exists {
163 self.remove_context(&context, cx);
164 None
165 } else {
166 Some(existing.as_ref().clone())
167 }
168 } else {
169 self.insert_context(context.clone(), cx);
170 Some(context)
171 };
172
173 anyhow::Ok(context)
174 }
175
176 pub fn add_symbol(
177 &mut self,
178 buffer: Entity<Buffer>,
179 symbol: SharedString,
180 range: Range<Anchor>,
181 enclosing_range: Range<Anchor>,
182 remove_if_exists: bool,
183 cx: &mut Context<Self>,
184 ) -> (Option<AgentContextHandle>, bool) {
185 let context_id = self.next_context_id.post_inc();
186 let context = AgentContextHandle::Symbol(SymbolContextHandle {
187 buffer,
188 symbol,
189 range,
190 enclosing_range,
191 context_id,
192 });
193
194 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
195 let handle = if remove_if_exists {
196 self.remove_context(&context, cx);
197 None
198 } else {
199 Some(key.as_ref().clone())
200 };
201 return (handle, false);
202 }
203
204 let included = self.insert_context(context.clone(), cx);
205 (Some(context), included)
206 }
207
208 pub fn add_thread(
209 &mut self,
210 thread: Entity<Thread>,
211 remove_if_exists: bool,
212 cx: &mut Context<Self>,
213 ) -> Option<AgentContextHandle> {
214 let context_id = self.next_context_id.post_inc();
215 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
216
217 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
218 if remove_if_exists {
219 self.remove_context(&context, cx);
220 None
221 } else {
222 Some(existing.as_ref().clone())
223 }
224 } else {
225 self.insert_context(context.clone(), cx);
226 Some(context)
227 }
228 }
229
230 pub fn add_rules(
231 &mut self,
232 prompt_id: UserPromptId,
233 remove_if_exists: bool,
234 cx: &mut Context<ContextStore>,
235 ) -> Option<AgentContextHandle> {
236 let context_id = self.next_context_id.post_inc();
237 let context = AgentContextHandle::Rules(RulesContextHandle {
238 prompt_id,
239 context_id,
240 });
241
242 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
243 if remove_if_exists {
244 self.remove_context(&context, cx);
245 None
246 } else {
247 Some(existing.as_ref().clone())
248 }
249 } else {
250 self.insert_context(context.clone(), cx);
251 Some(context)
252 }
253 }
254
255 pub fn add_fetched_url(
256 &mut self,
257 url: String,
258 text: impl Into<SharedString>,
259 cx: &mut Context<ContextStore>,
260 ) -> AgentContextHandle {
261 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
262 url: url.into(),
263 text: text.into(),
264 context_id: self.next_context_id.post_inc(),
265 });
266
267 self.insert_context(context.clone(), cx);
268 context
269 }
270
271 pub fn add_image_from_path(
272 &mut self,
273 project_path: ProjectPath,
274 remove_if_exists: bool,
275 cx: &mut Context<ContextStore>,
276 ) -> Task<Result<Option<AgentContextHandle>>> {
277 let project = self.project.clone();
278 cx.spawn(async move |this, cx| {
279 let open_image_task = project.update(cx, |project, cx| {
280 project.open_image(project_path.clone(), cx)
281 })?;
282 let image_item = open_image_task.await?;
283 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
284 this.update(cx, |this, cx| {
285 this.insert_image(
286 Some(image_item.read(cx).project_path(cx)),
287 image,
288 remove_if_exists,
289 cx,
290 )
291 })
292 })
293 }
294
295 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
296 self.insert_image(None, image, false, cx);
297 }
298
299 fn insert_image(
300 &mut self,
301 project_path: Option<ProjectPath>,
302 image: Arc<Image>,
303 remove_if_exists: bool,
304 cx: &mut Context<ContextStore>,
305 ) -> Option<AgentContextHandle> {
306 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
307 let context = AgentContextHandle::Image(ImageContext {
308 project_path,
309 original_image: image,
310 image_task,
311 context_id: self.next_context_id.post_inc(),
312 });
313 if self.has_context(&context) {
314 if remove_if_exists {
315 self.remove_context(&context, cx);
316 return None;
317 }
318 }
319
320 self.insert_context(context.clone(), cx);
321 Some(context)
322 }
323
324 pub fn add_selection(
325 &mut self,
326 buffer: Entity<Buffer>,
327 range: Range<Anchor>,
328 cx: &mut Context<ContextStore>,
329 ) {
330 let context_id = self.next_context_id.post_inc();
331 let context = AgentContextHandle::Selection(SelectionContextHandle {
332 buffer,
333 range,
334 context_id,
335 });
336 self.insert_context(context, cx);
337 }
338
339 pub fn add_suggested_context(
340 &mut self,
341 suggested: &SuggestedContext,
342 cx: &mut Context<ContextStore>,
343 ) {
344 match suggested {
345 SuggestedContext::File {
346 buffer,
347 icon_path: _,
348 name: _,
349 } => {
350 if let Some(buffer) = buffer.upgrade() {
351 let context_id = self.next_context_id.post_inc();
352 self.insert_context(
353 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
354 cx,
355 );
356 };
357 }
358 SuggestedContext::Thread { thread, name: _ } => {
359 if let Some(thread) = thread.upgrade() {
360 let context_id = self.next_context_id.post_inc();
361 self.insert_context(
362 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
363 cx,
364 );
365 }
366 }
367 }
368 }
369
370 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
371 match &context {
372 AgentContextHandle::Thread(thread_context) => {
373 if let Some(thread_store) = self.thread_store.clone() {
374 thread_context.thread.update(cx, |thread, cx| {
375 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
376 });
377 self.context_thread_ids
378 .insert(thread_context.thread.read(cx).id().clone());
379 } else {
380 return false;
381 }
382 }
383 _ => {}
384 }
385 let inserted = self.context_set.insert(AgentContextKey(context));
386 if inserted {
387 cx.notify();
388 }
389 inserted
390 }
391
392 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
393 if let Some((_, key)) = self
394 .context_set
395 .shift_remove_full(AgentContextKey::ref_cast(context))
396 {
397 match context {
398 AgentContextHandle::Thread(thread_context) => {
399 self.context_thread_ids
400 .remove(thread_context.thread.read(cx).id());
401 }
402 _ => {}
403 }
404 cx.emit(ContextStoreEvent::ContextRemoved(key));
405 cx.notify();
406 }
407 }
408
409 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
410 self.context_set
411 .contains(AgentContextKey::ref_cast(context))
412 }
413
414 /// Returns whether this file path is already included directly in the context, or if it will be
415 /// included in the context via a directory.
416 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
417 let project = self.project.upgrade()?.read(cx);
418 self.context().find_map(|context| match context {
419 AgentContextHandle::File(file_context) => {
420 FileInclusion::check_file(file_context, path, cx)
421 }
422 AgentContextHandle::Image(image_context) => {
423 FileInclusion::check_image(image_context, path)
424 }
425 AgentContextHandle::Directory(directory_context) => {
426 FileInclusion::check_directory(directory_context, path, project, cx)
427 }
428 _ => None,
429 })
430 }
431
432 pub fn path_included_in_directory(
433 &self,
434 path: &ProjectPath,
435 cx: &App,
436 ) -> Option<FileInclusion> {
437 let project = self.project.upgrade()?.read(cx);
438 self.context().find_map(|context| match context {
439 AgentContextHandle::Directory(directory_context) => {
440 FileInclusion::check_directory(directory_context, path, project, cx)
441 }
442 _ => None,
443 })
444 }
445
446 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
447 self.context().any(|context| match context {
448 AgentContextHandle::Symbol(context) => {
449 if context.symbol != symbol.name {
450 return false;
451 }
452 let buffer = context.buffer.read(cx);
453 let Some(context_path) = buffer.project_path(cx) else {
454 return false;
455 };
456 if context_path != symbol.path {
457 return false;
458 }
459 let context_range = context.range.to_point_utf16(&buffer.snapshot());
460 context_range.start == symbol.range.start.0
461 && context_range.end == symbol.range.end.0
462 }
463 _ => false,
464 })
465 }
466
467 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
468 self.context_thread_ids.contains(thread_id)
469 }
470
471 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
472 self.context_set
473 .contains(&RulesContextHandle::lookup_key(prompt_id))
474 }
475
476 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
477 self.context_set
478 .contains(&FetchedUrlContext::lookup_key(url.into()))
479 }
480
481 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
482 self.context_set
483 .get(&FetchedUrlContext::lookup_key(url))
484 .map(|key| key.as_ref().clone())
485 }
486
487 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
488 self.context()
489 .filter_map(|context| match context {
490 AgentContextHandle::File(file) => {
491 let buffer = file.buffer.read(cx);
492 buffer.project_path(cx)
493 }
494 AgentContextHandle::Directory(_)
495 | AgentContextHandle::Symbol(_)
496 | AgentContextHandle::Selection(_)
497 | AgentContextHandle::FetchedUrl(_)
498 | AgentContextHandle::Thread(_)
499 | AgentContextHandle::Rules(_)
500 | AgentContextHandle::Image(_) => None,
501 })
502 .collect()
503 }
504
505 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
506 &self.context_thread_ids
507 }
508}
509
/// How a project path is covered by the context stored in a [`ContextStore`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself was added: as a file, as an image, or as the directory entry itself.
    Direct,
    /// The path lies inside an added directory.
    ///
    /// `full_path` is the *containing directory's* path as reported by `Worktree::full_path`,
    /// not the path of the file itself.
    InDirectory { full_path: PathBuf },
}
514
515impl FileInclusion {
516 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
517 let file_path = file_context.buffer.read(cx).project_path(cx)?;
518 if path == &file_path {
519 Some(FileInclusion::Direct)
520 } else {
521 None
522 }
523 }
524
525 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
526 let image_path = image_context.project_path.as_ref()?;
527 if path == image_path {
528 Some(FileInclusion::Direct)
529 } else {
530 None
531 }
532 }
533
534 fn check_directory(
535 directory_context: &DirectoryContextHandle,
536 path: &ProjectPath,
537 project: &Project,
538 cx: &App,
539 ) -> Option<Self> {
540 let worktree = project
541 .worktree_for_entry(directory_context.entry_id, cx)?
542 .read(cx);
543 let entry = worktree.entry_for_id(directory_context.entry_id)?;
544 let directory_path = ProjectPath {
545 worktree_id: worktree.id(),
546 path: entry.path.clone(),
547 };
548 if path.starts_with(&directory_path) {
549 if path == &directory_path {
550 Some(FileInclusion::Direct)
551 } else {
552 Some(FileInclusion::InDirectory {
553 full_path: worktree.full_path(&entry.path),
554 })
555 }
556 } else {
557 None
558 }
559 }
560}