1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4};
5
6use anyhow::Context as _;
7use collections::HashMap;
8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
9use language::{
10 ContextProvider as _, LanguageToolchainStore, Location,
11 proto::{deserialize_anchor, serialize_anchor},
12};
13use rpc::{AnyProtoClient, TypedEnvelope, proto};
14use settings::{InvalidSettingsError, SettingsLocation, TaskKind};
15use task::{TaskContext, TaskVariables, VariableName};
16use text::{BufferId, OffsetRangeExt};
17use util::ResultExt;
18
19use crate::{
20 BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
21 worktree_store::WorktreeStore,
22};
23
/// Owns the task inventory and resolves task contexts for buffer locations.
///
/// The `Noop` variant carries no state: it exposes no inventory, resolves every
/// task context to `None`, ignores task settings updates, and rejects incoming
/// task context requests from peers.
// `Functional` is much larger than `Noop`, which makes clippy's
// `large_enum_variant` lint fire on some platforms only.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A store backed by a task inventory, operating locally or against an upstream host.
    Functional(StoreState),
    /// An inert store that performs no task work.
    Noop,
}
29
/// State held by a [`TaskStore::Functional`] store.
pub struct StoreState {
    /// Whether task contexts are computed in-process or requested from an upstream host.
    mode: StoreMode,
    /// Inventory of tasks, including those loaded from task settings files.
    task_inventory: Entity<Inventory>,
    /// Held weakly; upgraded on demand when a peer requests a task context.
    buffer_store: WeakEntity<BufferStore>,
    /// Used to find a buffer's worktree root and to build the basic task context.
    worktree_store: Entity<WorktreeStore>,
    /// Handed to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
}
37
/// How a functional task store obtains task contexts.
enum StoreMode {
    /// Task contexts are computed in this process.
    Local {
        /// Set by `TaskStore::shared` to the downstream client and the project's
        /// remote id; cleared again by `TaskStore::unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the buffer environment used for `project_env` in task contexts.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from the upstream host via
    /// `proto::TaskContextForLocation`.
    Remote {
        /// Client used to send requests upstream.
        upstream_client: AnyProtoClient,
        /// Project id sent with every upstream request.
        project_id: u64,
    },
}
48
// Allows `TaskStore` entities to emit project-level `crate::Event`s. No emit call
// is visible in this file section; presumably emitted elsewhere — verify.
impl EventEmitter<crate::Event> for TaskStore {}
50
/// Where a tasks settings file lives; passed to `TaskStore::update_user_tasks`.
#[derive(Debug)]
pub enum TaskSettingsLocation<'a> {
    /// A tasks file identified only by its path (presumably user-global settings — confirm).
    Global(&'a Path),
    /// A tasks file located within a worktree.
    Worktree(SettingsLocation<'a>),
}
56
57impl TaskStore {
58 pub fn init(client: Option<&AnyProtoClient>) {
59 if let Some(client) = client {
60 client.add_entity_request_handler(Self::handle_task_context_for_location);
61 }
62 }
63
64 async fn handle_task_context_for_location(
65 store: Entity<Self>,
66 envelope: TypedEnvelope<proto::TaskContextForLocation>,
67 mut cx: AsyncApp,
68 ) -> anyhow::Result<proto::TaskContext> {
69 let location = envelope
70 .payload
71 .location
72 .context("no location given for task context handling")?;
73 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
74 Ok(match store {
75 TaskStore::Functional(state) => (
76 state.buffer_store.clone(),
77 match &state.mode {
78 StoreMode::Local { .. } => false,
79 StoreMode::Remote { .. } => true,
80 },
81 ),
82 TaskStore::Noop => {
83 anyhow::bail!("empty task store cannot handle task context requests")
84 }
85 })
86 })??;
87 let buffer_store = buffer_store
88 .upgrade()
89 .context("no buffer store when handling task context request")?;
90
91 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
92 format!(
93 "cannot handle task context request for invalid buffer id: {}",
94 location.buffer_id
95 )
96 })?;
97
98 let start = location
99 .start
100 .and_then(deserialize_anchor)
101 .context("missing task context location start")?;
102 let end = location
103 .end
104 .and_then(deserialize_anchor)
105 .context("missing task context location end")?;
106 let buffer = buffer_store
107 .update(&mut cx, |buffer_store, cx| {
108 if is_remote {
109 buffer_store.wait_for_remote_buffer(buffer_id, cx)
110 } else {
111 Task::ready(
112 buffer_store
113 .get(buffer_id)
114 .with_context(|| format!("no local buffer with id {buffer_id}")),
115 )
116 }
117 })?
118 .await?;
119
120 let location = Location {
121 buffer,
122 range: start..end,
123 };
124 let context_task = store.update(&mut cx, |store, cx| {
125 let captured_variables = {
126 let mut variables = TaskVariables::from_iter(
127 envelope
128 .payload
129 .task_variables
130 .into_iter()
131 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132 );
133
134 let snapshot = location.buffer.read(cx).snapshot();
135 let range = location.range.to_offset(&snapshot);
136
137 for range in snapshot.runnable_ranges(range) {
138 for (capture_name, value) in range.extra_captures {
139 variables.insert(VariableName::Custom(capture_name.into()), value);
140 }
141 }
142 variables
143 };
144 store.task_context_for_location(captured_variables, location, cx)
145 })?;
146 let task_context = context_task.await.unwrap_or_default();
147 Ok(proto::TaskContext {
148 project_env: task_context.project_env.into_iter().collect(),
149 cwd: task_context
150 .cwd
151 .map(|cwd| cwd.to_string_lossy().to_string()),
152 task_variables: task_context
153 .task_variables
154 .into_iter()
155 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156 .collect(),
157 })
158 }
159
160 pub fn local(
161 buffer_store: WeakEntity<BufferStore>,
162 worktree_store: Entity<WorktreeStore>,
163 toolchain_store: Arc<dyn LanguageToolchainStore>,
164 environment: Entity<ProjectEnvironment>,
165 cx: &mut Context<Self>,
166 ) -> Self {
167 Self::Functional(StoreState {
168 mode: StoreMode::Local {
169 downstream_client: None,
170 environment,
171 },
172 task_inventory: Inventory::new(cx),
173 buffer_store,
174 toolchain_store,
175 worktree_store,
176 })
177 }
178
179 pub fn remote(
180 buffer_store: WeakEntity<BufferStore>,
181 worktree_store: Entity<WorktreeStore>,
182 toolchain_store: Arc<dyn LanguageToolchainStore>,
183 upstream_client: AnyProtoClient,
184 project_id: u64,
185 cx: &mut Context<Self>,
186 ) -> Self {
187 Self::Functional(StoreState {
188 mode: StoreMode::Remote {
189 upstream_client,
190 project_id,
191 },
192 task_inventory: Inventory::new(cx),
193 buffer_store,
194 toolchain_store,
195 worktree_store,
196 })
197 }
198
199 pub fn task_context_for_location(
200 &self,
201 captured_variables: TaskVariables,
202 location: Location,
203 cx: &mut App,
204 ) -> Task<Option<TaskContext>> {
205 match self {
206 TaskStore::Functional(state) => match &state.mode {
207 StoreMode::Local { environment, .. } => local_task_context_for_location(
208 state.worktree_store.clone(),
209 state.toolchain_store.clone(),
210 environment.clone(),
211 captured_variables,
212 location,
213 cx,
214 ),
215 StoreMode::Remote {
216 upstream_client,
217 project_id,
218 } => remote_task_context_for_location(
219 *project_id,
220 upstream_client.clone(),
221 state.worktree_store.clone(),
222 captured_variables,
223 location,
224 state.toolchain_store.clone(),
225 cx,
226 ),
227 },
228 TaskStore::Noop => Task::ready(None),
229 }
230 }
231
232 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233 match self {
234 TaskStore::Functional(state) => Some(&state.task_inventory),
235 TaskStore::Noop => None,
236 }
237 }
238
239 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240 if let Self::Functional(StoreState {
241 mode: StoreMode::Local {
242 downstream_client, ..
243 },
244 ..
245 }) = self
246 {
247 *downstream_client = Some((new_downstream_client, remote_id));
248 }
249 }
250
251 pub fn unshared(&mut self, _: &mut Context<Self>) {
252 if let Self::Functional(StoreState {
253 mode: StoreMode::Local {
254 downstream_client, ..
255 },
256 ..
257 }) = self
258 {
259 *downstream_client = None;
260 }
261 }
262
263 pub(super) fn update_user_tasks(
264 &self,
265 location: TaskSettingsLocation<'_>,
266 raw_tasks_json: Option<&str>,
267 task_type: TaskKind,
268 cx: &mut Context<Self>,
269 ) -> Result<(), InvalidSettingsError> {
270 let task_inventory = match self {
271 TaskStore::Functional(state) => &state.task_inventory,
272 TaskStore::Noop => return Ok(()),
273 };
274 let raw_tasks_json = raw_tasks_json
275 .map(|json| json.trim())
276 .filter(|json| !json.is_empty());
277
278 task_inventory.update(cx, |inventory, _| {
279 inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
280 })
281 }
282}
283
284fn local_task_context_for_location(
285 worktree_store: Entity<WorktreeStore>,
286 toolchain_store: Arc<dyn LanguageToolchainStore>,
287 environment: Entity<ProjectEnvironment>,
288 captured_variables: TaskVariables,
289 location: Location,
290 cx: &App,
291) -> Task<Option<TaskContext>> {
292 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
293 let worktree_abs_path = worktree_id
294 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
295 .and_then(|worktree| worktree.read(cx).root_dir());
296
297 cx.spawn(async move |cx| {
298 let project_env = environment
299 .update(cx, |environment, cx| {
300 environment.get_buffer_environment(
301 location.buffer.clone(),
302 worktree_store.clone(),
303 cx,
304 )
305 })
306 .ok()?
307 .await;
308
309 let mut task_variables = cx
310 .update(|cx| {
311 combine_task_variables(
312 captured_variables,
313 location,
314 project_env.clone(),
315 BasicContextProvider::new(worktree_store),
316 toolchain_store,
317 cx,
318 )
319 })
320 .ok()?
321 .await
322 .log_err()?;
323 // Remove all custom entries starting with _, as they're not intended for use by the end user.
324 task_variables.sweep();
325
326 Some(TaskContext {
327 project_env: project_env.unwrap_or_default(),
328 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
329 task_variables,
330 })
331 })
332}
333
334fn remote_task_context_for_location(
335 project_id: u64,
336 upstream_client: AnyProtoClient,
337 worktree_store: Entity<WorktreeStore>,
338 captured_variables: TaskVariables,
339 location: Location,
340 toolchain_store: Arc<dyn LanguageToolchainStore>,
341 cx: &mut App,
342) -> Task<Option<TaskContext>> {
343 cx.spawn(async move |cx| {
344 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
345 let mut remote_context = cx
346 .update(|cx| {
347 BasicContextProvider::new(worktree_store).build_context(
348 &TaskVariables::default(),
349 &location,
350 None,
351 toolchain_store,
352 cx,
353 )
354 })
355 .ok()?
356 .await
357 .log_err()
358 .unwrap_or_default();
359 remote_context.extend(captured_variables);
360
361 let buffer_id = cx
362 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
363 .ok()?;
364 let context_task = upstream_client.request(proto::TaskContextForLocation {
365 project_id,
366 location: Some(proto::Location {
367 buffer_id,
368 start: Some(serialize_anchor(&location.range.start)),
369 end: Some(serialize_anchor(&location.range.end)),
370 }),
371 task_variables: remote_context
372 .into_iter()
373 .map(|(k, v)| (k.to_string(), v))
374 .collect(),
375 });
376 let task_context = context_task.await.log_err()?;
377 Some(TaskContext {
378 cwd: task_context.cwd.map(PathBuf::from),
379 task_variables: task_context
380 .task_variables
381 .into_iter()
382 .filter_map(
383 |(variable_name, variable_value)| match variable_name.parse() {
384 Ok(variable_name) => Some((variable_name, variable_value)),
385 Err(()) => {
386 log::error!("Unknown variable name: {variable_name}");
387 None
388 }
389 },
390 )
391 .collect(),
392 project_env: task_context.project_env.into_iter().collect(),
393 })
394 })
395}
396
397fn combine_task_variables(
398 mut captured_variables: TaskVariables,
399 location: Location,
400 project_env: Option<HashMap<String, String>>,
401 baseline: BasicContextProvider,
402 toolchain_store: Arc<dyn LanguageToolchainStore>,
403 cx: &mut App,
404) -> Task<anyhow::Result<TaskVariables>> {
405 let language_context_provider = location
406 .buffer
407 .read(cx)
408 .language()
409 .and_then(|language| language.context_provider());
410 cx.spawn(async move |cx| {
411 let baseline = cx
412 .update(|cx| {
413 baseline.build_context(
414 &captured_variables,
415 &location,
416 project_env.clone(),
417 toolchain_store.clone(),
418 cx,
419 )
420 })?
421 .await
422 .context("building basic default context")?;
423 captured_variables.extend(baseline);
424 if let Some(provider) = language_context_provider {
425 captured_variables.extend(
426 cx.update(|cx| {
427 provider.build_context(
428 &captured_variables,
429 &location,
430 project_env,
431 toolchain_store,
432 cx,
433 )
434 })?
435 .await
436 .context("building provider context")?,
437 );
438 }
439 Ok(captured_variables)
440 })
441}