1use super::*;
2use gpui::TestAppContext;
3use http_client::FakeHttpClient;
4use language_model::{LanguageModelRequest, MessageContent, Role};
5
6#[gpui::test]
7fn test_local_provider_creation(cx: &mut TestAppContext) {
8 let http_client = FakeHttpClient::with_200_response();
9 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
10
11 cx.read(|cx| {
12 assert_eq!(provider.id(), PROVIDER_ID);
13 assert_eq!(provider.name(), PROVIDER_NAME);
14 assert!(!provider.is_authenticated(cx));
15 assert_eq!(provider.provided_models(cx).len(), 1);
16 });
17}
18
19#[gpui::test]
20fn test_state_initialization(cx: &mut TestAppContext) {
21 cx.update(|cx| {
22 let state = cx.new(State::new);
23
24 assert!(!state.read(cx).is_authenticated());
25 assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
26 assert!(state.read(cx).model.is_none());
27 });
28}
29
30#[gpui::test]
31fn test_model_properties(cx: &mut TestAppContext) {
32 let http_client = FakeHttpClient::with_200_response();
33 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
34
35 // Create a model directly for testing (bypassing authentication)
36 let model = LocalLanguageModel {
37 state: provider.state.clone(),
38 request_limiter: RateLimiter::new(4),
39 };
40
41 assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
42 assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
43 assert_eq!(model.provider_id(), PROVIDER_ID);
44 assert_eq!(model.provider_name(), PROVIDER_NAME);
45 assert_eq!(model.max_token_count(), 128000);
46 assert!(!model.supports_tools());
47 assert!(!model.supports_images());
48}
49
50#[gpui::test]
51async fn test_token_counting(cx: &mut TestAppContext) {
52 let http_client = FakeHttpClient::with_200_response();
53 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
54
55 let model = LocalLanguageModel {
56 state: provider.state.clone(),
57 request_limiter: RateLimiter::new(4),
58 };
59
60 let request = LanguageModelRequest {
61 thread_id: None,
62 prompt_id: None,
63 intent: None,
64 mode: None,
65 messages: vec![language_model::LanguageModelRequestMessage {
66 role: Role::User,
67 content: vec![MessageContent::Text("Hello, world!".to_string())],
68 cache: false,
69 }],
70 tools: Vec::new(),
71 tool_choice: None,
72 stop: Vec::new(),
73 temperature: None,
74 thinking_allowed: false,
75 };
76
77 let count = cx
78 .update(|cx| model.count_tokens(request, cx))
79 .await
80 .unwrap();
81
82 // "Hello, world!" is 13 characters, so ~3 tokens
83 assert!(count > 0);
84 assert!(count < 10);
85}
86
87#[gpui::test]
88async fn test_message_conversion(cx: &mut TestAppContext) {
89 let http_client = FakeHttpClient::with_200_response();
90 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
91
92 let model = LocalLanguageModel {
93 state: provider.state.clone(),
94 request_limiter: RateLimiter::new(4),
95 };
96
97 let request = LanguageModelRequest {
98 thread_id: None,
99 prompt_id: None,
100 intent: None,
101 mode: None,
102 messages: vec![
103 language_model::LanguageModelRequestMessage {
104 role: Role::System,
105 content: vec![MessageContent::Text(
106 "You are a helpful assistant.".to_string(),
107 )],
108 cache: false,
109 },
110 language_model::LanguageModelRequestMessage {
111 role: Role::User,
112 content: vec![MessageContent::Text("Hello!".to_string())],
113 cache: false,
114 },
115 language_model::LanguageModelRequestMessage {
116 role: Role::Assistant,
117 content: vec![MessageContent::Text("Hi there!".to_string())],
118 cache: false,
119 },
120 ],
121 tools: Vec::new(),
122 tool_choice: None,
123 stop: Vec::new(),
124 temperature: None,
125 thinking_allowed: false,
126 };
127
128 let _messages = model.to_mistral_messages(&request);
129 // We can't directly inspect TextMessages, but we can verify it doesn't panic
130 assert!(true); // Placeholder assertion
131}
132
133#[gpui::test]
134async fn test_reset_credentials(cx: &mut TestAppContext) {
135 let http_client = FakeHttpClient::with_200_response();
136 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
137
138 // Simulate loading a model by just setting the status
139 cx.update(|cx| {
140 provider.state.update(cx, |state, cx| {
141 state.status = ModelStatus::Loaded;
142 // We don't actually set a model since we can't mock it safely
143 cx.notify();
144 });
145 });
146
147 cx.read(|cx| {
148 // Since is_authenticated checks for model presence, we need to check status directly
149 assert_eq!(provider.state.read(cx).status, ModelStatus::Loaded);
150 });
151
152 // Reset credentials
153 let task = cx.update(|cx| provider.reset_credentials(cx));
154 task.await.unwrap();
155
156 cx.read(|cx| {
157 assert!(!provider.is_authenticated(cx));
158 assert_eq!(provider.state.read(cx).status, ModelStatus::NotLoaded);
159 assert!(provider.state.read(cx).model.is_none());
160 });
161}
162
163// TODO: Fix this test - need to handle window creation in tests
164// #[gpui::test]
165// async fn test_configuration_view_rendering(cx: &mut TestAppContext) {
166// let http_client = FakeHttpClient::with_200_response();
167// let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
168
169// let view = cx.update(|cx| provider.configuration_view(cx.window(), cx));
170
171// // Basic test to ensure the view can be created without panicking
172// assert!(view.entity_type() == std::any::TypeId::of::<ConfigurationView>());
173// }
174
175#[gpui::test]
176fn test_status_transitions(cx: &mut TestAppContext) {
177 cx.update(|cx| {
178 let state = cx.new(State::new);
179
180 // Initial state
181 assert_eq!(state.read(cx).status, ModelStatus::NotLoaded);
182
183 // Transition to loading
184 state.update(cx, |state, cx| {
185 state.status = ModelStatus::Loading;
186 cx.notify();
187 });
188 assert_eq!(state.read(cx).status, ModelStatus::Loading);
189
190 // Transition to loaded
191 state.update(cx, |state, cx| {
192 state.status = ModelStatus::Loaded;
193 cx.notify();
194 });
195 assert_eq!(state.read(cx).status, ModelStatus::Loaded);
196
197 // Transition to error
198 state.update(cx, |state, cx| {
199 state.status = ModelStatus::Error("Test error".to_string());
200 cx.notify();
201 });
202 match &state.read(cx).status {
203 ModelStatus::Error(msg) => assert_eq!(msg, "Test error"),
204 _ => panic!("Expected error status"),
205 }
206 });
207}
208
209#[gpui::test]
210fn test_provider_shows_models_without_authentication(cx: &mut TestAppContext) {
211 let http_client = FakeHttpClient::with_200_response();
212 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
213
214 cx.read(|cx| {
215 // Provider should show models even when not authenticated
216 let models = provider.provided_models(cx);
217 assert_eq!(models.len(), 1);
218
219 let model = &models[0];
220 assert_eq!(model.id(), LanguageModelId(DEFAULT_MODEL.into()));
221 assert_eq!(model.name(), LanguageModelName(DEFAULT_MODEL.into()));
222 assert_eq!(model.provider_id(), PROVIDER_ID);
223 assert_eq!(model.provider_name(), PROVIDER_NAME);
224 });
225}
226
227#[gpui::test]
228fn test_provider_has_icon(cx: &mut TestAppContext) {
229 let http_client = FakeHttpClient::with_200_response();
230 let provider = cx.update(|cx| LocalLanguageModelProvider::new(Arc::new(http_client), cx));
231
232 assert_eq!(provider.icon(), IconName::Ai);
233}
234
235#[gpui::test]
236fn test_provider_appears_in_registry(cx: &mut TestAppContext) {
237 use language_model::LanguageModelRegistry;
238
239 cx.update(|cx| {
240 let registry = cx.new(|_| LanguageModelRegistry::default());
241 let http_client = FakeHttpClient::with_200_response();
242
243 // Register the local provider
244 registry.update(cx, |registry, cx| {
245 let provider = LocalLanguageModelProvider::new(Arc::new(http_client), cx);
246 registry.register_provider(provider, cx);
247 });
248
249 // Verify the provider is registered
250 let provider = registry.read(cx).provider(&PROVIDER_ID).unwrap();
251 assert_eq!(provider.name(), PROVIDER_NAME);
252 assert_eq!(provider.icon(), IconName::Ai);
253
254 // Verify it provides models even without authentication
255 let models = provider.provided_models(cx);
256 assert_eq!(models.len(), 1);
257 assert_eq!(models[0].id(), LanguageModelId(DEFAULT_MODEL.into()));
258 });
259}