1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4};
5
6use anyhow::Context as _;
7use collections::HashMap;
8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
9use language::{
10 ContextProvider as _, LanguageToolchainStore, Location,
11 proto::{deserialize_anchor, serialize_anchor},
12};
13use rpc::{AnyProtoClient, TypedEnvelope, proto};
14use settings::{InvalidSettingsError, SettingsLocation};
15use task::{TaskContext, TaskVariables, VariableName};
16use text::{BufferId, OffsetRangeExt};
17use util::ResultExt;
18
19use crate::{
20 BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
21 worktree_store::WorktreeStore,
22};
23
24// platform-dependent warning
25pub enum TaskStore {
26 Functional(StoreState),
27 Noop,
28}
29
30pub struct StoreState {
31 mode: StoreMode,
32 task_inventory: Entity<Inventory>,
33 buffer_store: WeakEntity<BufferStore>,
34 worktree_store: Entity<WorktreeStore>,
35 toolchain_store: Arc<dyn LanguageToolchainStore>,
36}
37
38enum StoreMode {
39 Local {
40 downstream_client: Option<(AnyProtoClient, u64)>,
41 environment: Entity<ProjectEnvironment>,
42 },
43 Remote {
44 upstream_client: AnyProtoClient,
45 project_id: u64,
46 },
47}
48
49impl EventEmitter<crate::Event> for TaskStore {}
50
51#[derive(Debug)]
52pub enum TaskSettingsLocation<'a> {
53 Global(&'a Path),
54 Worktree(SettingsLocation<'a>),
55}
56
57impl TaskStore {
58 pub fn init(client: Option<&AnyProtoClient>) {
59 if let Some(client) = client {
60 client.add_entity_request_handler(Self::handle_task_context_for_location);
61 }
62 }
63
64 async fn handle_task_context_for_location(
65 store: Entity<Self>,
66 envelope: TypedEnvelope<proto::TaskContextForLocation>,
67 mut cx: AsyncApp,
68 ) -> anyhow::Result<proto::TaskContext> {
69 let location = envelope
70 .payload
71 .location
72 .context("no location given for task context handling")?;
73 let (buffer_store, is_remote) = store.read_with(&mut cx, |store, _| {
74 Ok(match store {
75 TaskStore::Functional(state) => (
76 state.buffer_store.clone(),
77 match &state.mode {
78 StoreMode::Local { .. } => false,
79 StoreMode::Remote { .. } => true,
80 },
81 ),
82 TaskStore::Noop => {
83 anyhow::bail!("empty task store cannot handle task context requests")
84 }
85 })
86 })??;
87 let buffer_store = buffer_store
88 .upgrade()
89 .context("no buffer store when handling task context request")?;
90
91 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
92 format!(
93 "cannot handle task context request for invalid buffer id: {}",
94 location.buffer_id
95 )
96 })?;
97
98 let start = location
99 .start
100 .and_then(deserialize_anchor)
101 .context("missing task context location start")?;
102 let end = location
103 .end
104 .and_then(deserialize_anchor)
105 .context("missing task context location end")?;
106 let buffer = buffer_store
107 .update(&mut cx, |buffer_store, cx| {
108 if is_remote {
109 buffer_store.wait_for_remote_buffer(buffer_id, cx)
110 } else {
111 Task::ready(
112 buffer_store
113 .get(buffer_id)
114 .with_context(|| format!("no local buffer with id {buffer_id}")),
115 )
116 }
117 })?
118 .await?;
119
120 let location = Location {
121 buffer,
122 range: start..end,
123 };
124 let context_task = store.update(&mut cx, |store, cx| {
125 let captured_variables = {
126 let mut variables = TaskVariables::from_iter(
127 envelope
128 .payload
129 .task_variables
130 .into_iter()
131 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132 );
133
134 let snapshot = location.buffer.read(cx).snapshot();
135 let range = location.range.to_offset(&snapshot);
136
137 for range in snapshot.runnable_ranges(range) {
138 for (capture_name, value) in range.extra_captures {
139 variables.insert(VariableName::Custom(capture_name.into()), value);
140 }
141 }
142 variables
143 };
144 store.task_context_for_location(captured_variables, location, cx)
145 })?;
146 let task_context = context_task.await.unwrap_or_default();
147 Ok(proto::TaskContext {
148 project_env: task_context.project_env.into_iter().collect(),
149 cwd: task_context
150 .cwd
151 .map(|cwd| cwd.to_string_lossy().to_string()),
152 task_variables: task_context
153 .task_variables
154 .into_iter()
155 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156 .collect(),
157 })
158 }
159
160 pub fn local(
161 buffer_store: WeakEntity<BufferStore>,
162 worktree_store: Entity<WorktreeStore>,
163 toolchain_store: Arc<dyn LanguageToolchainStore>,
164 environment: Entity<ProjectEnvironment>,
165 cx: &mut Context<Self>,
166 ) -> Self {
167 Self::Functional(StoreState {
168 mode: StoreMode::Local {
169 downstream_client: None,
170 environment,
171 },
172 task_inventory: Inventory::new(cx),
173 buffer_store,
174 toolchain_store,
175 worktree_store,
176 })
177 }
178
179 pub fn remote(
180 buffer_store: WeakEntity<BufferStore>,
181 worktree_store: Entity<WorktreeStore>,
182 toolchain_store: Arc<dyn LanguageToolchainStore>,
183 upstream_client: AnyProtoClient,
184 project_id: u64,
185 cx: &mut Context<Self>,
186 ) -> Self {
187 Self::Functional(StoreState {
188 mode: StoreMode::Remote {
189 upstream_client,
190 project_id,
191 },
192 task_inventory: Inventory::new(cx),
193 buffer_store,
194 toolchain_store,
195 worktree_store,
196 })
197 }
198
199 pub fn task_context_for_location(
200 &self,
201 captured_variables: TaskVariables,
202 location: Location,
203 cx: &mut App,
204 ) -> Task<Option<TaskContext>> {
205 match self {
206 TaskStore::Functional(state) => match &state.mode {
207 StoreMode::Local { environment, .. } => local_task_context_for_location(
208 state.worktree_store.clone(),
209 state.toolchain_store.clone(),
210 environment.clone(),
211 captured_variables,
212 location,
213 cx,
214 ),
215 StoreMode::Remote {
216 upstream_client,
217 project_id,
218 } => remote_task_context_for_location(
219 *project_id,
220 upstream_client.clone(),
221 state.worktree_store.clone(),
222 captured_variables,
223 location,
224 state.toolchain_store.clone(),
225 cx,
226 ),
227 },
228 TaskStore::Noop => Task::ready(None),
229 }
230 }
231
232 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233 match self {
234 TaskStore::Functional(state) => Some(&state.task_inventory),
235 TaskStore::Noop => None,
236 }
237 }
238
239 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240 if let Self::Functional(StoreState {
241 mode: StoreMode::Local {
242 downstream_client, ..
243 },
244 ..
245 }) = self
246 {
247 *downstream_client = Some((new_downstream_client, remote_id));
248 }
249 }
250
251 pub fn unshared(&mut self, _: &mut Context<Self>) {
252 if let Self::Functional(StoreState {
253 mode: StoreMode::Local {
254 downstream_client, ..
255 },
256 ..
257 }) = self
258 {
259 *downstream_client = None;
260 }
261 }
262
263 pub(super) fn update_user_tasks(
264 &self,
265 location: TaskSettingsLocation<'_>,
266 raw_tasks_json: Option<&str>,
267 cx: &mut Context<Self>,
268 ) -> Result<(), InvalidSettingsError> {
269 let task_inventory = match self {
270 TaskStore::Functional(state) => &state.task_inventory,
271 TaskStore::Noop => return Ok(()),
272 };
273 let raw_tasks_json = raw_tasks_json
274 .map(|json| json.trim())
275 .filter(|json| !json.is_empty());
276
277 task_inventory.update(cx, |inventory, _| {
278 inventory.update_file_based_tasks(location, raw_tasks_json)
279 })
280 }
281
282 pub(super) fn update_user_debug_scenarios(
283 &self,
284 location: TaskSettingsLocation<'_>,
285 raw_tasks_json: Option<&str>,
286 cx: &mut Context<Self>,
287 ) -> Result<(), InvalidSettingsError> {
288 let task_inventory = match self {
289 TaskStore::Functional(state) => &state.task_inventory,
290 TaskStore::Noop => return Ok(()),
291 };
292 let raw_tasks_json = raw_tasks_json
293 .map(|json| json.trim())
294 .filter(|json| !json.is_empty());
295
296 task_inventory.update(cx, |inventory, _| {
297 inventory.update_file_based_scenarios(location, raw_tasks_json)
298 })
299 }
300}
301
302fn local_task_context_for_location(
303 worktree_store: Entity<WorktreeStore>,
304 toolchain_store: Arc<dyn LanguageToolchainStore>,
305 environment: Entity<ProjectEnvironment>,
306 captured_variables: TaskVariables,
307 location: Location,
308 cx: &App,
309) -> Task<Option<TaskContext>> {
310 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
311 let worktree_abs_path = worktree_id
312 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
313 .and_then(|worktree| worktree.read(cx).root_dir());
314
315 cx.spawn(async move |cx| {
316 let project_env = environment
317 .update(cx, |environment, cx| {
318 environment.get_buffer_environment(&location.buffer, &worktree_store, cx)
319 })
320 .ok()?
321 .await;
322
323 let mut task_variables = cx
324 .update(|cx| {
325 combine_task_variables(
326 captured_variables,
327 location,
328 project_env.clone(),
329 BasicContextProvider::new(worktree_store),
330 toolchain_store,
331 cx,
332 )
333 })
334 .ok()?
335 .await
336 .log_err()?;
337 // Remove all custom entries starting with _, as they're not intended for use by the end user.
338 task_variables.sweep();
339
340 Some(TaskContext {
341 project_env: project_env.unwrap_or_default(),
342 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
343 task_variables,
344 })
345 })
346}
347
348fn remote_task_context_for_location(
349 project_id: u64,
350 upstream_client: AnyProtoClient,
351 worktree_store: Entity<WorktreeStore>,
352 captured_variables: TaskVariables,
353 location: Location,
354 toolchain_store: Arc<dyn LanguageToolchainStore>,
355 cx: &mut App,
356) -> Task<Option<TaskContext>> {
357 cx.spawn(async move |cx| {
358 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
359 let mut remote_context = cx
360 .update(|cx| {
361 BasicContextProvider::new(worktree_store).build_context(
362 &TaskVariables::default(),
363 &location,
364 None,
365 toolchain_store,
366 cx,
367 )
368 })
369 .ok()?
370 .await
371 .log_err()
372 .unwrap_or_default();
373 remote_context.extend(captured_variables);
374
375 let buffer_id = cx
376 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
377 .ok()?;
378 let context_task = upstream_client.request(proto::TaskContextForLocation {
379 project_id,
380 location: Some(proto::Location {
381 buffer_id,
382 start: Some(serialize_anchor(&location.range.start)),
383 end: Some(serialize_anchor(&location.range.end)),
384 }),
385 task_variables: remote_context
386 .into_iter()
387 .map(|(k, v)| (k.to_string(), v))
388 .collect(),
389 });
390 let task_context = context_task.await.log_err()?;
391 Some(TaskContext {
392 cwd: task_context.cwd.map(PathBuf::from),
393 task_variables: task_context
394 .task_variables
395 .into_iter()
396 .filter_map(
397 |(variable_name, variable_value)| match variable_name.parse() {
398 Ok(variable_name) => Some((variable_name, variable_value)),
399 Err(()) => {
400 log::error!("Unknown variable name: {variable_name}");
401 None
402 }
403 },
404 )
405 .collect(),
406 project_env: task_context.project_env.into_iter().collect(),
407 })
408 })
409}
410
411fn combine_task_variables(
412 mut captured_variables: TaskVariables,
413 location: Location,
414 project_env: Option<HashMap<String, String>>,
415 baseline: BasicContextProvider,
416 toolchain_store: Arc<dyn LanguageToolchainStore>,
417 cx: &mut App,
418) -> Task<anyhow::Result<TaskVariables>> {
419 let language_context_provider = location
420 .buffer
421 .read(cx)
422 .language()
423 .and_then(|language| language.context_provider());
424 cx.spawn(async move |cx| {
425 let baseline = cx
426 .update(|cx| {
427 baseline.build_context(
428 &captured_variables,
429 &location,
430 project_env.clone(),
431 toolchain_store.clone(),
432 cx,
433 )
434 })?
435 .await
436 .context("building basic default context")?;
437 captured_variables.extend(baseline);
438 if let Some(provider) = language_context_provider {
439 captured_variables.extend(
440 cx.update(|cx| {
441 provider.build_context(
442 &captured_variables,
443 &location,
444 project_env,
445 toolchain_store,
446 cx,
447 )
448 })?
449 .await
450 .context("building provider context")?,
451 );
452 }
453 Ok(captured_variables)
454 })
455}