@@ -1,6 +1,7 @@
use anyhow::{anyhow, Result};
use assistant_slash_command::{
- ArgumentCompletion, SlashCommand, SlashCommandOutput, SlashCommandOutputSection,
+ AfterCompletion, ArgumentCompletion, SlashCommand, SlashCommandOutput,
+ SlashCommandOutputSection,
};
use collections::HashMap;
use context_servers::{
@@ -8,7 +9,7 @@ use context_servers::{
protocol::PromptInfo,
};
use gpui::{Task, WeakView, WindowContext};
-use language::LspAdapterDelegate;
+use language::{CodeLabel, LspAdapterDelegate};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use ui::{IconName, SharedString};
@@ -50,12 +51,57 @@ impl SlashCommand for ContextServerSlashCommand {
fn complete_argument(
self: Arc<Self>,
- _arguments: &[String],
+ arguments: &[String],
_cancel: Arc<AtomicBool>,
_workspace: Option<WeakView<Workspace>>,
- _cx: &mut WindowContext,
+ cx: &mut WindowContext,
) -> Task<Result<Vec<ArgumentCompletion>>> {
- Task::ready(Ok(Vec::new()))
+ let server_id = self.server_id.clone();
+ let prompt_name = self.prompt.name.clone();
+ let manager = ContextServerManager::global(cx);
+ let manager = manager.read(cx);
+
+ let (arg_name, arg_val) = match completion_argument(&self.prompt, arguments) {
+ Ok(tp) => tp,
+ Err(e) => {
+ return Task::ready(Err(e));
+ }
+ };
+ if let Some(server) = manager.get_server(&server_id) {
+ cx.foreground_executor().spawn(async move {
+ let Some(protocol) = server.client.read().clone() else {
+ return Err(anyhow!("Context server not initialized"));
+ };
+
+ let completion_result = protocol
+ .completion(
+ context_servers::types::CompletionReference::Prompt(
+ context_servers::types::PromptReference {
+ r#type: context_servers::types::PromptReferenceType::Prompt,
+ name: prompt_name,
+ },
+ ),
+ arg_name,
+ arg_val,
+ )
+ .await?;
+
+ let completions = completion_result
+ .values
+ .into_iter()
+ .map(|value| ArgumentCompletion {
+ label: CodeLabel::plain(value.clone(), None),
+ new_text: value,
+ after_completion: AfterCompletion::Continue,
+ replace_previous_arguments: false,
+ })
+ .collect();
+
+ Ok(completions)
+ })
+ } else {
+ Task::ready(Err(anyhow!("Context server not found")))
+ }
}
fn run(
@@ -102,6 +148,22 @@ impl SlashCommand for ContextServerSlashCommand {
}
}
+fn completion_argument(prompt: &PromptInfo, arguments: &[String]) -> Result<(String, String)> {
+ if arguments.is_empty() {
+ return Err(anyhow!("No arguments given"));
+ }
+
+ match &prompt.arguments {
+ Some(args) if args.len() == 1 => {
+ let arg_name = args[0].name.clone();
+ let arg_value = arguments.join(" ");
+ Ok((arg_name, arg_value))
+ }
+ Some(_) => Err(anyhow!("Prompt must have exactly one argument")),
+ None => Err(anyhow!("Prompt has no arguments")),
+ }
+}
+
fn prompt_arguments(prompt: &PromptInfo, arguments: &[String]) -> Result<HashMap<String, String>> {
match &prompt.arguments {
Some(args) if args.len() > 1 => Err(anyhow!(
@@ -127,6 +127,35 @@ impl InitializedContextServerProtocol {
Ok(response)
}
+
+ pub async fn completion<P: Into<String>>(
+ &self,
+ reference: types::CompletionReference,
+ argument: P,
+ value: P,
+ ) -> Result<types::Completion> {
+ let params = types::CompletionCompleteParams {
+ r#ref: reference,
+ argument: types::CompletionArgument {
+ name: argument.into(),
+ value: value.into(),
+ },
+ };
+ let result: types::CompletionCompleteResponse = self
+ .inner
+ .request(types::RequestType::CompletionComplete.as_str(), params)
+ .await?;
+
+ let completion = types::Completion {
+ values: result.completion.values,
+ total: types::CompletionTotal::from_options(
+ result.completion.has_more,
+ result.completion.total,
+ ),
+ };
+
+ Ok(completion)
+ }
}
impl InitializedContextServerProtocol {
@@ -14,6 +14,7 @@ pub enum RequestType {
LoggingSetLevel,
PromptsGet,
PromptsList,
+ CompletionComplete,
}
impl RequestType {
@@ -28,6 +29,7 @@ impl RequestType {
RequestType::LoggingSetLevel => "logging/setLevel",
RequestType::PromptsGet => "prompts/get",
RequestType::PromptsList => "prompts/list",
+ RequestType::CompletionComplete => "completion/complete",
}
}
}
@@ -78,6 +80,50 @@ pub struct PromptsGetParams {
pub arguments: Option<HashMap<String, String>>,
}
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionCompleteParams {
+ pub r#ref: CompletionReference,
+ pub argument: CompletionArgument,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(untagged)]
+pub enum CompletionReference {
+ Prompt(PromptReference),
+ Resource(ResourceReference),
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptReference {
+ pub r#type: PromptReferenceType,
+ pub name: String,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum PromptReferenceType {
+ #[serde(rename = "ref/prompt")]
+ Prompt,
+ #[serde(rename = "ref/resource")]
+ Resource,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ResourceReference {
+ pub r#type: String,
+ pub uri: String,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionArgument {
+ pub name: String,
+ pub value: String,
+}
+
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
@@ -112,6 +158,20 @@ pub struct PromptsListResponse {
pub prompts: Vec<PromptInfo>,
}
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionCompleteResponse {
+ pub completion: CompletionResult,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CompletionResult {
+ pub values: Vec<String>,
+ pub total: Option<u32>,
+ pub has_more: Option<bool>,
+}
+
#[derive(Debug, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PromptInfo {
@@ -233,3 +293,26 @@ pub struct ProgressParams {
pub progress: f64,
pub total: Option<f64>,
}
+
+// Helper Types that don't map directly to the protocol
+
+pub enum CompletionTotal {
+ Exact(u32),
+ HasMore,
+ Unknown,
+}
+
+impl CompletionTotal {
+ pub fn from_options(has_more: Option<bool>, total: Option<u32>) -> Self {
+ match (has_more, total) {
+ (_, Some(count)) => CompletionTotal::Exact(count),
+ (Some(true), _) => CompletionTotal::HasMore,
+ _ => CompletionTotal::Unknown,
+ }
+ }
+}
+
+pub struct Completion {
+ pub values: Vec<String>,
+ pub total: CompletionTotal,
+}