@@ -4,9 +4,12 @@ use std::{
};
use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use parking_lot::Mutex;
use crate::{
auth::{CredentialProvider, NullCredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
embedding::{Embedding, EmbeddingProvider},
models::{LanguageModel, TruncationDirection},
};
@@ -125,3 +128,39 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
}
}
+
+pub struct TestCompletionProvider {
+ last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl TestCompletionProvider {
+ pub fn new() -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+
+ pub fn send_completion(&self, completion: impl Into<String>) {
+ let mut tx = self.last_completion_tx.lock();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.last_completion_tx.lock().take().unwrap();
+ }
+}
+
+impl CompletionProvider for TestCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
+ fn complete(
+ &self,
+ _prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+ let (tx, rx) = mpsc::channel(1);
+ *self.last_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+ }
+}
@@ -44,6 +44,7 @@ tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
+ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true
@@ -335,7 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use ai::{models::LanguageModel, test::FakeLanguageModel};
+ use ai::test::TestCompletionProvider;
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@@ -617,42 +617,6 @@ mod tests {
}
}
- struct TestCompletionProvider {
- last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
- }
-
- impl TestCompletionProvider {
- fn new() -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-
- fn send_completion(&self, completion: impl Into<String>) {
- let mut tx = self.last_completion_tx.lock();
- tx.as_mut().unwrap().try_send(completion.into()).unwrap();
- }
-
- fn finish_completion(&self) {
- self.last_completion_tx.lock().take().unwrap();
- }
- }
-
- impl CompletionProvider for TestCompletionProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
- model
- }
- fn complete(
- &self,
- _prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let (tx, rx) = mpsc::channel(1);
- *self.last_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
- }
- }
-
fn rust_lang() -> Language {
Language::new(
LanguageConfig {