1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::{self, FutureExt};
8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
9use language::Buffer;
10use language_model::LanguageModelImage;
11use project::image_store::is_image_file;
12use project::{Project, ProjectItem, ProjectPath, Symbol};
13use prompt_store::UserPromptId;
14use ref_cast::RefCast as _;
15use text::{Anchor, OffsetRangeExt};
16
17use crate::ThreadStore;
18use crate::context::{
19 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
20 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
21 SymbolContextHandle, ThreadContextHandle,
22};
23use crate::context_strip::SuggestedContext;
24use crate::thread::{MessageId, Thread, ThreadId};
25
26pub struct ContextStore {
27 project: WeakEntity<Project>,
28 thread_store: Option<WeakEntity<ThreadStore>>,
29 next_context_id: ContextId,
30 context_set: IndexSet<AgentContextKey>,
31 context_thread_ids: HashSet<ThreadId>,
32}
33
34impl ContextStore {
35 pub fn new(
36 project: WeakEntity<Project>,
37 thread_store: Option<WeakEntity<ThreadStore>>,
38 ) -> Self {
39 Self {
40 project,
41 thread_store,
42 next_context_id: ContextId::zero(),
43 context_set: IndexSet::default(),
44 context_thread_ids: HashSet::default(),
45 }
46 }
47
48 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
49 self.context_set.iter().map(|entry| entry.as_ref())
50 }
51
52 pub fn clear(&mut self) {
53 self.context_set.clear();
54 self.context_thread_ids.clear();
55 }
56
57 pub fn new_context_for_thread(
58 &self,
59 thread: &Thread,
60 exclude_messages_from_id: Option<MessageId>,
61 ) -> Vec<AgentContextHandle> {
62 let existing_context = thread
63 .messages()
64 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
65 .flat_map(|message| {
66 message
67 .loaded_context
68 .contexts
69 .iter()
70 .map(|context| AgentContextKey(context.handle()))
71 })
72 .collect::<HashSet<_>>();
73 self.context_set
74 .iter()
75 .filter(|context| !existing_context.contains(context))
76 .map(|entry| entry.0.clone())
77 .collect::<Vec<_>>()
78 }
79
80 pub fn add_file_from_path(
81 &mut self,
82 project_path: ProjectPath,
83 remove_if_exists: bool,
84 cx: &mut Context<Self>,
85 ) -> Task<Result<()>> {
86 let Some(project) = self.project.upgrade() else {
87 return Task::ready(Err(anyhow!("failed to read project")));
88 };
89
90 if is_image_file(&project, &project_path, cx) {
91 self.add_image_from_path(project_path, remove_if_exists, cx)
92 } else {
93 cx.spawn(async move |this, cx| {
94 let open_buffer_task = project.update(cx, |project, cx| {
95 project.open_buffer(project_path.clone(), cx)
96 })?;
97 let buffer = open_buffer_task.await?;
98 this.update(cx, |this, cx| {
99 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
100 })
101 })
102 }
103 }
104
105 pub fn add_file_from_buffer(
106 &mut self,
107 project_path: &ProjectPath,
108 buffer: Entity<Buffer>,
109 remove_if_exists: bool,
110 cx: &mut Context<Self>,
111 ) {
112 let context_id = self.next_context_id.post_inc();
113 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
114
115 let already_included = if self.has_context(&context) {
116 if remove_if_exists {
117 self.remove_context(&context, cx);
118 }
119 true
120 } else {
121 self.path_included_in_directory(project_path, cx).is_some()
122 };
123
124 if !already_included {
125 self.insert_context(context, cx);
126 }
127 }
128
129 pub fn add_directory(
130 &mut self,
131 project_path: &ProjectPath,
132 remove_if_exists: bool,
133 cx: &mut Context<Self>,
134 ) -> Result<()> {
135 let Some(project) = self.project.upgrade() else {
136 return Err(anyhow!("failed to read project"));
137 };
138
139 let Some(entry_id) = project
140 .read(cx)
141 .entry_for_path(project_path, cx)
142 .map(|entry| entry.id)
143 else {
144 return Err(anyhow!("no entry found for directory context"));
145 };
146
147 let context_id = self.next_context_id.post_inc();
148 let context = AgentContextHandle::Directory(DirectoryContextHandle {
149 entry_id,
150 context_id,
151 });
152
153 if self.has_context(&context) {
154 if remove_if_exists {
155 self.remove_context(&context, cx);
156 }
157 } else if self.path_included_in_directory(project_path, cx).is_none() {
158 self.insert_context(context, cx);
159 }
160
161 anyhow::Ok(())
162 }
163
164 pub fn add_symbol(
165 &mut self,
166 buffer: Entity<Buffer>,
167 symbol: SharedString,
168 range: Range<Anchor>,
169 enclosing_range: Range<Anchor>,
170 remove_if_exists: bool,
171 cx: &mut Context<Self>,
172 ) -> bool {
173 let context_id = self.next_context_id.post_inc();
174 let context = AgentContextHandle::Symbol(SymbolContextHandle {
175 buffer,
176 symbol,
177 range,
178 enclosing_range,
179 context_id,
180 });
181
182 if self.has_context(&context) {
183 if remove_if_exists {
184 self.remove_context(&context, cx);
185 }
186 return false;
187 }
188
189 self.insert_context(context, cx)
190 }
191
192 pub fn add_thread(
193 &mut self,
194 thread: Entity<Thread>,
195 remove_if_exists: bool,
196 cx: &mut Context<Self>,
197 ) {
198 let context_id = self.next_context_id.post_inc();
199 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
200
201 if self.has_context(&context) {
202 if remove_if_exists {
203 self.remove_context(&context, cx);
204 }
205 } else {
206 self.insert_context(context, cx);
207 }
208 }
209
210 pub fn add_rules(
211 &mut self,
212 prompt_id: UserPromptId,
213 remove_if_exists: bool,
214 cx: &mut Context<ContextStore>,
215 ) {
216 let context_id = self.next_context_id.post_inc();
217 let context = AgentContextHandle::Rules(RulesContextHandle {
218 prompt_id,
219 context_id,
220 });
221
222 if self.has_context(&context) {
223 if remove_if_exists {
224 self.remove_context(&context, cx);
225 }
226 } else {
227 self.insert_context(context, cx);
228 }
229 }
230
231 pub fn add_fetched_url(
232 &mut self,
233 url: String,
234 text: impl Into<SharedString>,
235 cx: &mut Context<ContextStore>,
236 ) {
237 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
238 url: url.into(),
239 text: text.into(),
240 context_id: self.next_context_id.post_inc(),
241 });
242
243 self.insert_context(context, cx);
244 }
245
246 pub fn add_image_from_path(
247 &mut self,
248 project_path: ProjectPath,
249 remove_if_exists: bool,
250 cx: &mut Context<ContextStore>,
251 ) -> Task<Result<()>> {
252 let project = self.project.clone();
253 cx.spawn(async move |this, cx| {
254 let open_image_task = project.update(cx, |project, cx| {
255 project.open_image(project_path.clone(), cx)
256 })?;
257 let image_item = open_image_task.await?;
258 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
259 this.update(cx, |this, cx| {
260 this.insert_image(
261 Some(image_item.read(cx).project_path(cx)),
262 image,
263 remove_if_exists,
264 cx,
265 );
266 })
267 })
268 }
269
270 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
271 self.insert_image(None, image, false, cx);
272 }
273
274 fn insert_image(
275 &mut self,
276 project_path: Option<ProjectPath>,
277 image: Arc<Image>,
278 remove_if_exists: bool,
279 cx: &mut Context<ContextStore>,
280 ) {
281 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
282 let context = AgentContextHandle::Image(ImageContext {
283 project_path,
284 original_image: image,
285 image_task,
286 context_id: self.next_context_id.post_inc(),
287 });
288 if self.has_context(&context) {
289 if remove_if_exists {
290 self.remove_context(&context, cx);
291 return;
292 }
293 }
294
295 self.insert_context(context, cx);
296 }
297
298 pub fn add_selection(
299 &mut self,
300 buffer: Entity<Buffer>,
301 range: Range<Anchor>,
302 cx: &mut Context<ContextStore>,
303 ) {
304 let context_id = self.next_context_id.post_inc();
305 let context = AgentContextHandle::Selection(SelectionContextHandle {
306 buffer,
307 range,
308 context_id,
309 });
310 self.insert_context(context, cx);
311 }
312
313 pub fn add_suggested_context(
314 &mut self,
315 suggested: &SuggestedContext,
316 cx: &mut Context<ContextStore>,
317 ) {
318 match suggested {
319 SuggestedContext::File {
320 buffer,
321 icon_path: _,
322 name: _,
323 } => {
324 if let Some(buffer) = buffer.upgrade() {
325 let context_id = self.next_context_id.post_inc();
326 self.insert_context(
327 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
328 cx,
329 );
330 };
331 }
332 SuggestedContext::Thread { thread, name: _ } => {
333 if let Some(thread) = thread.upgrade() {
334 let context_id = self.next_context_id.post_inc();
335 self.insert_context(
336 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
337 cx,
338 );
339 }
340 }
341 }
342 }
343
344 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
345 match &context {
346 AgentContextHandle::Thread(thread_context) => {
347 if let Some(thread_store) = self.thread_store.clone() {
348 thread_context.thread.update(cx, |thread, cx| {
349 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
350 });
351 self.context_thread_ids
352 .insert(thread_context.thread.read(cx).id().clone());
353 } else {
354 return false;
355 }
356 }
357 _ => {}
358 }
359 let inserted = self.context_set.insert(AgentContextKey(context));
360 if inserted {
361 cx.notify();
362 }
363 inserted
364 }
365
366 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
367 if self
368 .context_set
369 .shift_remove(AgentContextKey::ref_cast(context))
370 {
371 match context {
372 AgentContextHandle::Thread(thread_context) => {
373 self.context_thread_ids
374 .remove(thread_context.thread.read(cx).id());
375 }
376 _ => {}
377 }
378 cx.notify();
379 }
380 }
381
382 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
383 self.context_set
384 .contains(AgentContextKey::ref_cast(context))
385 }
386
387 /// Returns whether this file path is already included directly in the context, or if it will be
388 /// included in the context via a directory.
389 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
390 let project = self.project.upgrade()?.read(cx);
391 self.context().find_map(|context| match context {
392 AgentContextHandle::File(file_context) => {
393 FileInclusion::check_file(file_context, path, cx)
394 }
395 AgentContextHandle::Image(image_context) => {
396 FileInclusion::check_image(image_context, path)
397 }
398 AgentContextHandle::Directory(directory_context) => {
399 FileInclusion::check_directory(directory_context, path, project, cx)
400 }
401 _ => None,
402 })
403 }
404
405 pub fn path_included_in_directory(
406 &self,
407 path: &ProjectPath,
408 cx: &App,
409 ) -> Option<FileInclusion> {
410 let project = self.project.upgrade()?.read(cx);
411 self.context().find_map(|context| match context {
412 AgentContextHandle::Directory(directory_context) => {
413 FileInclusion::check_directory(directory_context, path, project, cx)
414 }
415 _ => None,
416 })
417 }
418
419 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
420 self.context().any(|context| match context {
421 AgentContextHandle::Symbol(context) => {
422 if context.symbol != symbol.name {
423 return false;
424 }
425 let buffer = context.buffer.read(cx);
426 let Some(context_path) = buffer.project_path(cx) else {
427 return false;
428 };
429 if context_path != symbol.path {
430 return false;
431 }
432 let context_range = context.range.to_point_utf16(&buffer.snapshot());
433 context_range.start == symbol.range.start.0
434 && context_range.end == symbol.range.end.0
435 }
436 _ => false,
437 })
438 }
439
440 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
441 self.context_thread_ids.contains(thread_id)
442 }
443
444 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
445 self.context_set
446 .contains(&RulesContextHandle::lookup_key(prompt_id))
447 }
448
449 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
450 self.context_set
451 .contains(&FetchedUrlContext::lookup_key(url.into()))
452 }
453
454 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
455 self.context()
456 .filter_map(|context| match context {
457 AgentContextHandle::File(file) => {
458 let buffer = file.buffer.read(cx);
459 buffer.project_path(cx)
460 }
461 AgentContextHandle::Directory(_)
462 | AgentContextHandle::Symbol(_)
463 | AgentContextHandle::Selection(_)
464 | AgentContextHandle::FetchedUrl(_)
465 | AgentContextHandle::Thread(_)
466 | AgentContextHandle::Rules(_)
467 | AgentContextHandle::Image(_) => None,
468 })
469 .collect()
470 }
471
472 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
473 &self.context_thread_ids
474 }
475}
476
/// Describes how a project path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is attached as context.
    Direct,
    /// The path is covered by an attached directory; `full_path` is that directory's full path.
    InDirectory { full_path: PathBuf },
}
481
482impl FileInclusion {
483 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
484 let file_path = file_context.buffer.read(cx).project_path(cx)?;
485 if path == &file_path {
486 Some(FileInclusion::Direct)
487 } else {
488 None
489 }
490 }
491
492 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
493 let image_path = image_context.project_path.as_ref()?;
494 if path == image_path {
495 Some(FileInclusion::Direct)
496 } else {
497 None
498 }
499 }
500
501 fn check_directory(
502 directory_context: &DirectoryContextHandle,
503 path: &ProjectPath,
504 project: &Project,
505 cx: &App,
506 ) -> Option<Self> {
507 let worktree = project
508 .worktree_for_entry(directory_context.entry_id, cx)?
509 .read(cx);
510 let entry = worktree.entry_for_id(directory_context.entry_id)?;
511 let directory_path = ProjectPath {
512 worktree_id: worktree.id(),
513 path: entry.path.clone(),
514 };
515 if path.starts_with(&directory_path) {
516 if path == &directory_path {
517 Some(FileInclusion::Direct)
518 } else {
519 Some(FileInclusion::InDirectory {
520 full_path: worktree.full_path(&entry.path),
521 })
522 }
523 } else {
524 None
525 }
526 }
527}