tests.rs

  1use super::*;
  2use gpui::TestAppContext;
  3use http_client::FakeHttpClient;
  4use language_model::{LanguageModelRequest, MessageContent, Role};
  5
  6#[gpui::test]
  7fn test_local_provider_creation(cx: &mut TestAppContext) {
  8    let http_client = FakeHttpClient::with_200_response();
  9    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
 10
 11    cx.read(|cx| {
 12        assert_eq!(provider.id(), PROVIDER_ID);
 13        assert_eq!(provider.name(), PROVIDER_NAME);
 14        assert!(!provider.is_authenticated(cx));
 15        assert_eq!(provider.provided_models(cx).len(), 1);
 16    });
 17}
 18
 19#[gpui::test]
 20fn test_state_initialization(cx: &mut TestAppContext) {
 21    cx.update(|cx| {
 22        let state = cx.new(State::new);
 23
 24        assert!(!state.read(cx).is_authenticated());
 25        assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
 26        assert!(state.read(cx).model.is_none());
 27    });
 28}
 29
 30#[gpui::test]
 31fn test_model_properties(cx: &mut TestAppContext) {
 32    let http_client = FakeHttpClient::with_200_response();
 33    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
 34
 35    // Create a model directly for testing (bypassing authentication)
 36    let model = LocalLanguageModel {
 37        state: provider.state.clone(),
 38        request_limiter: RateLimiter::new(4),
 39    };
 40
 41    assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
 42    assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
 43    assert_eq!(model.provider_id(), PROVIDER_ID);
 44    assert_eq!(model.provider_name(), PROVIDER_NAME);
 45    assert_eq!(model.max_token_count(), 128000);
 46    assert!(!model.supports_tools());
 47    assert!(!model.supports_images());
 48}
 49
 50#[gpui::test]
 51async fn test_token_counting(cx: &mut TestAppContext) {
 52    let http_client = FakeHttpClient::with_200_response();
 53    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
 54
 55    let model = LocalLanguageModel {
 56        state: provider.state.clone(),
 57        request_limiter: RateLimiter::new(4),
 58    };
 59
 60    let request = LanguageModelRequest {
 61        thread_id: None,
 62        prompt_id: None,
 63        intent: None,
 64        mode: None,
 65        messages: vec![language_model::LanguageModelRequestMessage {
 66            role: Role::User,
 67            content: vec![MessageContent::Text("Hello, world!".to_string())],
 68            cache: false,
 69        }],
 70        tools: Vec::new(),
 71        tool_choice: None,
 72        stop: Vec::new(),
 73        temperature: None,
 74        thinking_allowed: false,
 75    };
 76
 77    let count = cx
 78        .update(|cx| model.count_tokens(request, cx))
 79        .await
 80        .unwrap();
 81
 82    // "Hello, world!" is 13 characters, so ~3 tokens
 83    assert!(count > 0);
 84    assert!(count < 10);
 85}
 86
 87#[gpui::test]
 88async fn test_message_conversion(cx: &mut TestAppContext) {
 89    let http_client = FakeHttpClient::with_200_response();
 90    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
 91
 92    let model = LocalLanguageModel {
 93        state: provider.state.clone(),
 94        request_limiter: RateLimiter::new(4),
 95    };
 96
 97    let request = LanguageModelRequest {
 98        thread_id: None,
 99        prompt_id: None,
100        intent: None,
101        mode: None,
102        messages: vec![
103            language_model::LanguageModelRequestMessage {
104                role: Role::System,
105                content: vec![MessageContent::Text(
106                    "You are a helpful assistant.".to_string(),
107                )],
108                cache: false,
109            },
110            language_model::LanguageModelRequestMessage {
111                role: Role::User,
112                content: vec![MessageContent::Text("Hello!".to_string())],
113                cache: false,
114            },
115            language_model::LanguageModelRequestMessage {
116                role: Role::Assistant,
117                content: vec![MessageContent::Text("Hi there!".to_string())],
118                cache: false,
119            },
120        ],
121        tools: Vec::new(),
122        tool_choice: None,
123        stop: Vec::new(),
124        temperature: None,
125        thinking_allowed: false,
126    };
127
128    let _messages = model.to_mistral_messages(&request);
129    // We can't directly inspect TextMessages, but we can verify it doesn't panic
130    assert!(true); // Placeholder assertion
131}
132
133#[gpui::test]
134async fn test_reset_credentials(cx: &mut TestAppContext) {
135    let http_client = FakeHttpClient::with_200_response();
136    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
137
138    // Simulate loading a model by just setting the status
139    cx.update(|cx| {
140        provider.state.update(cx, |state, cx| {
141            state.status = ModelStatus::Loaded;
142            // We don't actually set a model since we can't mock it safely
143            cx.notify();
144        });
145    });
146
147    cx.read(|cx| {
148        // Since is_authenticated checks for model presence, we need to check status directly
149        assert_eq!(provider.state.read(cx).status, ModelStatus::Loaded);
150    });
151
152    // Reset credentials
153    let task = cx.update(|cx| provider.reset_credentials(cx));
154    task.await.unwrap();
155
156    cx.read(|cx| {
157        assert!(!provider.is_authenticated(cx));
158        assert_eq!(provider.state.read(cx).status, ModelStatus::NotLoaded);
159        assert!(provider.state.read(cx).model.is_none());
160    });
161}
162
163// TODO: Fix this test - need to handle window creation in tests
164// #[gpui::test]
165// async fn test_configuration_view_rendering(cx: &mut TestAppContext) {
166//     let http_client = FakeHttpClient::with_200_response();
167//     let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
168
169//     let view = cx.update(|cx| provider.configuration_view(cx.window(), cx));
170
171//     // Basic test to ensure the view can be created without panicking
172//     assert!(view.entity_type() == std::any::TypeId::of::<ConfigurationView>());
173// }
174
175#[gpui::test]
176fn test_status_transitions(cx: &mut TestAppContext) {
177    cx.update(|cx| {
178        let state = cx.new(State::new);
179
180        // Initial state
181        assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
182
183        // Transition to loading
184        state.update(cx, |state, cx| {
185            state.status = ModelStatus::Loading;
186            cx.notify();
187        });
188        assert_eq!(state.read(cx).status, ModelStatus::Loading);
189
190        // Transition to loaded
191        state.update(cx, |state, cx| {
192            state.status = ModelStatus::Loaded;
193            cx.notify();
194        });
195        assert_eq!(state.read(cx).status, ModelStatus::Loaded);
196
197        // Transition to error
198        state.update(cx, |state, cx| {
199            state.status = ModelStatus::Error("Test error".to_string());
200            cx.notify();
201        });
202        match &state.read(cx).status {
203            ModelStatus::Error(msg) => assert_eq!(msg, "Test error"),
204            _ => panic!("Expected error status"),
205        }
206    });
207}
208
209#[gpui::test]
210fn test_provider_shows_models_without_authentication(cx: &mut TestAppContext) {
211    let http_client = FakeHttpClient::with_200_response();
212    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
213
214    cx.read(|cx| {
215        // Provider should show models even when not authenticated
216        let models = provider.provided_models(cx);
217        assert_eq!(models.len(), 1);
218
219        let model = &models[0];
220        assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
221        assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
222        assert_eq!(model.provider_id(), PROVIDER_ID);
223        assert_eq!(model.provider_name(), PROVIDER_NAME);
224    });
225}
226
227#[gpui::test]
228fn test_provider_has_icon(cx: &mut TestAppContext) {
229    let http_client = FakeHttpClient::with_200_response();
230    let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
231
232    assert_eq!(provider.icon(), IconName::Ai);
233}
234
235#[gpui::test]
236fn test_provider_appears_in_registry(cx: &mut TestAppContext) {
237    use language_model::LanguageModelRegistry;
238
239    cx.update(|cx| {
240        let registry = cx.new(|_| LanguageModelRegistry::default());
241        let http_client = FakeHttpClient::with_200_response();
242
243        // Register the local provider
244        registry.update(cx, |registry, cx| {
245            let provider = LocalLanguageModelProvider::new(Arc::new(http_client), cx);
246            registry.register_provider(provider, cx);
247        });
248
249        // Verify the provider is registered
250        let provider = registry.read(cx).provider(&PROVIDER_ID).unwrap();
251        assert_eq!(provider.name(), PROVIDER_NAME);
252        assert_eq!(provider.icon(), IconName::Ai);
253
254        // Verify it provides models even without authentication
255        let models = provider.provided_models(cx);
256        assert_eq!(models.len(), 1);
257        assert_eq!(models[0].id(), LanguageModelId(DEFAULT_MODEL.into()));
258    });
259}