diff --git a/cmd/test-ollama/main.go b/cmd/test-ollama-bkp/main.go
similarity index 100%
rename from cmd/test-ollama/main.go
rename to cmd/test-ollama-bkp/main.go
diff --git a/internal/config/load.go b/internal/config/load.go
index 9d662515af646f8dd37cf2f114f93b23382580ec..f902138d52fa0cdd7beb0f40a3aacd9e225d00d1 100644
--- a/internal/config/load.go
+++ b/internal/config/load.go
@@ -90,6 +90,27 @@ func Load(workingDir string, debug bool) (*Config, error) {
 	return cfg, nil
 }
 
+// convertOllamaModels converts ollama.ProviderModel to provider.Model
+func convertOllamaModels(ollamaModels []ollama.ProviderModel) []provider.Model {
+	providerModels := make([]provider.Model, len(ollamaModels))
+	for i, model := range ollamaModels {
+		providerModels[i] = provider.Model{
+			ID:                 model.ID,
+			Model:              model.Model,
+			CostPer1MIn:        model.CostPer1MIn,
+			CostPer1MOut:       model.CostPer1MOut,
+			CostPer1MInCached:  model.CostPer1MInCached,
+			CostPer1MOutCached: model.CostPer1MOutCached,
+			ContextWindow:      model.ContextWindow,
+			DefaultMaxTokens:   model.DefaultMaxTokens,
+			CanReason:          model.CanReason,
+			HasReasoningEffort: model.HasReasoningEffort,
+			SupportsImages:     model.SupportsImages,
+		}
+	}
+	return providerModels
+}
+
 func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, knownProviders []provider.Provider) error {
 	knownProviderNames := make(map[string]bool)
 	for _, p := range knownProviders {
@@ -201,11 +222,11 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
 				BaseURL: "http://localhost:11434/v1",
 				Type:    provider.TypeOpenAI,
 				APIKey:  "ollama",
-				Models:  ollamaProvider.Models,
+				Models:  convertOllamaModels(ollamaProvider.Models),
 			}
 		} else {
 			// If Ollama is not running, try to start it
-			if err := ollama.EnsureOllamaRunning(ctx); err == nil {
+			if err := ollama.EnsureRunning(ctx); err == nil {
 				// Now try to get the provider again
 				if ollamaProvider, err := ollama.GetProvider(ctx); err == nil {
 					slog.Debug("Started Ollama service and detected provider", "models", len(ollamaProvider.Models))
@@ -215,7 +236,7 @@ func (cfg *Config) configureProviders(env env.Env, resolver VariableResolver, kn
 					BaseURL: "http://localhost:11434/v1",
 					Type:    provider.TypeOpenAI,
 					APIKey:  "ollama",
-					Models:  ollamaProvider.Models,
+					Models:  convertOllamaModels(ollamaProvider.Models),
 				}
 			} else {
 				slog.Debug("Started Ollama service but failed to get provider", "error", err)
diff --git a/internal/ollama/cleanup.go b/internal/ollama/cleanup.go
new file mode 100644
index 0000000000000000000000000000000000000000..e9ffe17c89d06b4f4cb16697e4e284c80b414980
--- /dev/null
+++ b/internal/ollama/cleanup.go
@@ -0,0 +1,87 @@
+package ollama
+
+import (
+	"context"
+	"os"
+	"os/exec"
+	"os/signal"
+	"syscall"
+	"time"
+)
+
+// setupCleanup sets up signal handlers for cleanup
+func setupCleanup() {
+	c := make(chan os.Signal, 1)
+	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
+
+	go func() {
+		<-c
+		cleanup()
+		os.Exit(0)
+	}()
+}
+
+// cleanup stops all running models and the service if started by Crush
+func cleanup() {
+	processManager.mu.Lock()
+	defer processManager.mu.Unlock()
+
+	// Stop all running models using HTTP API
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	if IsRunning(ctx) {
+		stopAllModels(ctx)
+	}
+
+	// Stop Ollama service if we started it
+	if processManager.crushStartedOllama && processManager.ollamaProcess != nil {
+		stopOllamaService()
+	}
+}
+
+// stopAllModels stops all running models
+func stopAllModels(ctx context.Context) {
+	runningModels, err := GetRunningModels(ctx)
+	if err != nil {
+		return
+	}
+
+	for _, model := range runningModels {
+		stopModel(ctx, model.Name)
+	}
+}
+
+// stopModel stops a specific model using CLI
+func stopModel(ctx context.Context, modelName string) error {
+	cmd := exec.CommandContext(ctx, "ollama", "stop", modelName)
+	return cmd.Run()
+}
+
+// stopOllamaService stops the Ollama service process
+func stopOllamaService() {
+	if processManager.ollamaProcess == nil {
+		return
+	}
+
+	// Try graceful shutdown first
+	if err := processManager.ollamaProcess.Process.Signal(syscall.SIGTERM); err == nil {
+		// Wait for graceful shutdown
+		done := make(chan error, 1)
+		go func() {
+			done <- processManager.ollamaProcess.Wait()
+		}()
+
+		select {
+		case <-done:
+			// Process finished gracefully
+		case <-time.After(5 * time.Second):
+			// Force kill if not shut down gracefully
+			syscall.Kill(-processManager.ollamaProcess.Process.Pid, syscall.SIGKILL)
+			processManager.ollamaProcess.Wait()
+		}
+	}
+
+	processManager.ollamaProcess = nil
+	processManager.crushStartedOllama = false
+}
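Note: cleanup.go and service.go both rely on a package-level processManager whose ProcessManager type is not part of this diff. A minimal sketch of its shape, inferred from how the fields are used here (the RWMutex follows the RLock calls in the old process_test.go); the struct layout is an assumption, not code from the PR:

package ollama

import (
	"os/exec"
	"sync"
)

// ProcessManager tracks the "ollama serve" process that Crush may have started.
type ProcessManager struct {
	mu                 sync.RWMutex // guards the fields below
	setupOnce          sync.Once    // installs signal handlers exactly once
	ollamaProcess      *exec.Cmd    // the serve process, when Crush started it
	crushStartedOllama bool         // true when Crush owns the service lifecycle
}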
diff --git a/internal/ollama/cleanup_test.go b/internal/ollama/cleanup_test.go
index 3d935f3619dc865e5b13524f9052444089b39f99..e3562ec82f78f3d2fafe52b442d7b129ab3c2f8a 100644
--- a/internal/ollama/cleanup_test.go
+++ b/internal/ollama/cleanup_test.go
@@ -7,22 +7,33 @@ import (
 	"time"
 )
 
-// TestCleanupOnExit tests that Ollama models are properly stopped when Crush exits
-func TestCleanupOnExit(t *testing.T) {
+func TestProcessManagementWithRealModel(t *testing.T) {
 	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping cleanup test")
+		t.Skip("Ollama is not installed, skipping process management test")
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
 
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
+	// Start with a clean state
+	originallyRunning := IsRunning(ctx)
+	t.Logf("Ollama originally running: %v", originallyRunning)
+
+	// If Ollama wasn't running, we'll start it and be responsible for cleanup
+	var shouldCleanup bool
+	if !originallyRunning {
+		shouldCleanup = true
 		t.Log("Starting Ollama service...")
-		if err := StartOllamaService(ctx); err != nil {
+
+		if err := StartService(ctx); err != nil {
 			t.Fatalf("Failed to start Ollama service: %v", err)
 		}
-		defer cleanupProcesses() // Clean up at the end
+
+		if !IsRunning(ctx) {
+			t.Fatal("Started Ollama service but it's not running")
+		}
+
+		t.Log("✓ Ollama service started successfully")
 	}
 
 	// Get available models
@@ -32,152 +43,203 @@ func TestCleanupOnExit(t *testing.T) {
 	}
 
 	if len(models) == 0 {
-		t.Skip("No models available, skipping cleanup test")
+		t.Skip("No models available, skipping model loading test")
 	}
 
-	// Pick a small model for testing
-	testModel := models[0].ID
+	// Choose a test model (prefer smaller models)
+	testModel := models[0].Name
 	for _, model := range models {
-		if model.ID == "phi3:3.8b" || model.ID == "llama3.2:3b" {
-			testModel = model.ID
+		if model.Name == "phi3:3.8b" || model.Name == "llama3.2:3b" {
+			testModel = model.Name
 			break
 		}
 	}
 
-	t.Logf("Testing cleanup with model: %s", testModel)
+	t.Logf("Testing with model: %s", testModel)
 
-	// Check if model is already loaded
-	loaded, err := IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded: %v", err)
+	// Test 1: Load model
+	t.Log("Loading model...")
+	startTime := time.Now()
+
+	if err := EnsureModelLoaded(ctx, testModel); err != nil {
+		t.Fatalf("Failed to load model: %v", err)
 	}
 
-	// If not loaded, start it
-	if !loaded {
-		t.Log("Starting model for cleanup test...")
-		if err := StartModel(ctx, testModel); err != nil {
-			t.Fatalf("Failed to start model: %v", err)
-		}
+	loadTime := time.Since(startTime)
+	t.Logf("✓ Model loaded in %v", loadTime)
 
-		// Verify it's now loaded
-		loaded, err = IsModelLoaded(ctx, testModel)
-		if err != nil {
-			t.Fatalf("Failed to check if model is loaded after start: %v", err)
-		}
-		if !loaded {
-			t.Fatal("Model failed to load")
-		}
-		t.Log("Model loaded successfully")
-	} else {
-		t.Log("Model was already loaded")
+	// Verify model is running
+	running, err := IsModelRunning(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is running: %v", err)
+	}
+
+	if !running {
+		t.Fatal("Model should be running but isn't")
 	}
 
-	// Now test the cleanup
-	t.Log("Testing cleanup process...")
+	t.Log("✓ Model is confirmed running")
 
-	// Simulate what happens when Crush exits
-	cleanupProcesses()
+	// Test 2: Immediate cleanup after loading
+	t.Log("Testing immediate cleanup after model load...")
 
-	// Give some time for cleanup
-	time.Sleep(3 * time.Second)
+	cleanupStart := time.Now()
+	cleanup()
+	cleanupTime := time.Since(cleanupStart)
 
-	// Check if model is still loaded
-	loaded, err = IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded after cleanup: %v", err)
-	}
+	t.Logf("✓ Cleanup completed in %v", cleanupTime)
 
-	if loaded {
-		t.Error("Model is still loaded after cleanup - cleanup failed")
+	// Give cleanup time to take effect
+	time.Sleep(2 * time.Second)
+
+	// Test 3: Verify cleanup worked
+	if shouldCleanup {
+		// If we started Ollama, it should be stopped
+		if IsRunning(ctx) {
+			t.Error("❌ Ollama service should be stopped after cleanup but it's still running")
+		} else {
+			t.Log("✓ Ollama service properly stopped after cleanup")
+		}
 	} else {
-		t.Log("Model successfully unloaded after cleanup")
+		// If Ollama was already running, it should still be running but model should be stopped
+		if !IsRunning(ctx) {
+			t.Error("❌ Ollama service should still be running but it's stopped")
+		} else {
+			t.Log("✓ Ollama service still running (as expected)")
+
+			// Check if model is still loaded
+			running, err := IsModelRunning(ctx, testModel)
+			if err != nil {
+				t.Errorf("Failed to check model status after cleanup: %v", err)
+			} else if running {
+				t.Error("❌ Model should be stopped after cleanup but it's still running")
+			} else {
+				t.Log("✓ Model properly stopped after cleanup")
+			}
+		}
+	}
+
+	// Test 4: Test cleanup idempotency
+	t.Log("Testing cleanup idempotency...")
+	cleanup()
+	cleanup()
+	cleanup()
+	t.Log("✓ Multiple cleanup calls handled safely")
 }
 
-// TestCleanupWithMockProcess tests cleanup functionality with a mock process
 func TestCleanupWithMockProcess(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping mock cleanup test")
-	}
-
-	// Create a mock long-running process to simulate a model
+	// Test cleanup mechanism with a mock process that simulates Ollama
 	cmd := exec.Command("sleep", "30")
 	if err := cmd.Start(); err != nil {
 		t.Fatalf("Failed to start mock process: %v", err)
 	}
 
-	// Add it to our process manager
+	pid := cmd.Process.Pid
+	t.Logf("Started mock process with PID: %d", pid)
+
+	// Simulate what happens in our process manager
 	processManager.mu.Lock()
-	if processManager.processes == nil {
-		processManager.processes = make(map[string]*exec.Cmd)
-	}
-	processManager.processes["mock-model"] = cmd
+	processManager.ollamaProcess = cmd
+	processManager.crushStartedOllama = true
 	processManager.mu.Unlock()
 
-	t.Logf("Started mock process with PID: %d", cmd.Process.Pid)
+	// Test cleanup
+	t.Log("Testing cleanup with mock process...")
+	stopOllamaService()
 
-	// Verify the process is running
-	if cmd.Process == nil {
-		t.Fatal("Mock process is nil")
+	// Verify process was terminated
+	if cmd.ProcessState != nil && cmd.ProcessState.Exited() {
+		t.Log("✓ Mock process was successfully terminated")
+	} else {
+		// Process might still be terminating
+		time.Sleep(100 * time.Millisecond)
+		if cmd.ProcessState != nil && cmd.ProcessState.Exited() {
+			t.Log("✓ Mock process was successfully terminated")
+		} else {
+			t.Error("❌ Mock process was not terminated")
+		}
 	}
+}
 
-	// Check if the process is actually running
-	if cmd.ProcessState != nil && cmd.ProcessState.Exited() {
-		t.Fatal("Mock process has already exited")
+func TestSetupCleanup(t *testing.T) {
+	// Test that setupCleanup can be called without panicking
+	defer func() {
+		if r := recover(); r != nil {
+			t.Fatalf("setupCleanup panicked: %v", r)
+		}
+	}()
+
+	// This should not panic and should be safe to call multiple times
+	setupCleanup()
+	t.Log("✓ setupCleanup completed without panic")
+}
+
+func TestStopModel(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping stopModel test")
 	}
 
-	// Test cleanup
-	t.Log("Testing cleanup with mock process...")
-	cleanupProcesses()
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
 
-	// Give some time for cleanup
-	time.Sleep(1 * time.Second)
+	// Ensure Ollama is running
+	if err := EnsureRunning(ctx); err != nil {
+		t.Fatalf("Failed to ensure Ollama is running: %v", err)
+	}
 
-	// The new CLI-based cleanup only stops Ollama models, not arbitrary processes
-	// So we need to manually clean up the mock process from our process manager
-	processManager.mu.Lock()
-	if mockCmd, exists := processManager.processes["mock-model"]; exists {
-		if mockCmd.Process != nil {
-			mockCmd.Process.Kill()
-		}
-		delete(processManager.processes, "mock-model")
+	// Get available models
+	models, err := GetModels(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get models: %v", err)
 	}
-	processManager.mu.Unlock()
 
-	// Manually terminate the mock process since it's not an Ollama model
-	if cmd.Process != nil {
-		cmd.Process.Kill()
+	if len(models) == 0 {
+		t.Skip("No models available, skipping stopModel test")
 	}
 
-	// Give some time for termination
-	time.Sleep(500 * time.Millisecond)
+	testModel := models[0].Name
+	t.Logf("Testing stop with model: %s", testModel)
 
-	// Check if process was terminated
-	if cmd.ProcessState != nil && cmd.ProcessState.Exited() {
-		t.Log("Mock process was successfully terminated")
+	// Load the model first
+	if err := EnsureModelLoaded(ctx, testModel); err != nil {
+		t.Fatalf("Failed to load model: %v", err)
+	}
+
+	// Verify it's running
+	running, err := IsModelRunning(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is running: %v", err)
+	}
+
+	if !running {
+		t.Fatal("Model should be running but isn't")
+	}
+
+	// Test stopping the model
+	t.Log("Stopping model...")
+	if err := stopModel(ctx, testModel); err != nil {
+		t.Fatalf("Failed to stop model: %v", err)
+	}
+
+	// Give it time to stop
+	time.Sleep(2 * time.Second)
+
+	// Verify it's stopped
+	running, err = IsModelRunning(ctx, testModel)
+	if err != nil {
+		t.Fatalf("Failed to check if model is running after stop: %v", err)
+	}
+
+	if running {
+		t.Error("❌ Model should be stopped but it's still running")
 	} else {
-		// Try to wait for the process to check its state
-		if err := cmd.Wait(); err != nil {
-			t.Log("Mock process was successfully terminated")
-		} else {
-			t.Error("Mock process is still running after cleanup")
-		}
+		t.Log("✓ Model successfully stopped")
 	}
-}
 
-// TestCleanupIdempotency tests that cleanup can be called multiple times safely
-func TestCleanupIdempotency(t *testing.T) {
-	// This test should not panic or cause issues when called multiple times
+	// Cleanup
 	defer func() {
-		if r := recover(); r != nil {
-			t.Fatalf("Cleanup panicked: %v", r)
+		if processManager.crushStartedOllama {
+			cleanup()
 		}
 	}()
-
-	// Call cleanup multiple times
-	cleanupProcesses()
-	cleanupProcesses()
-	cleanupProcesses()
-
-	t.Log("Cleanup is idempotent and safe to call multiple times")
 }
diff --git a/internal/ollama/cli.go b/internal/ollama/cli.go
deleted file mode 100644
index 2d04f66372159e7ac19a3bc7146eb90325d1b23d..0000000000000000000000000000000000000000
--- a/internal/ollama/cli.go
+++ /dev/null
@@ -1,208 +0,0 @@
-package ollama
-
-import (
-	"context"
-	"fmt"
-	"os/exec"
-	"strings"
-	"time"
-)
-
-// CLI-based approach for Ollama operations
-// These functions use the ollama CLI instead of HTTP requests
-
-// CLIListModels lists available models using ollama CLI
-func CLIListModels(ctx context.Context) ([]OllamaModel, error) {
-	cmd := exec.CommandContext(ctx, "ollama", "list")
-	output, err := cmd.Output()
-	if err != nil {
-		return nil, fmt.Errorf("failed to list models via CLI: %w", err)
-	}
-
-	return parseModelsList(string(output))
-}
-
-// parseModelsList parses the text output from 'ollama list'
-func parseModelsList(output string) ([]OllamaModel, error) {
-	lines := strings.Split(strings.TrimSpace(output), "\n")
-	if len(lines) < 2 {
-		return nil, fmt.Errorf("unexpected output format")
-	}
-
-	var models []OllamaModel
-	// Skip the header line
-	for i := 1; i < len(lines); i++ {
-		line := strings.TrimSpace(lines[i])
-		if line == "" {
-			continue
-		}
-
-		// Parse each line: NAME ID SIZE MODIFIED
-		fields := strings.Fields(line)
-		if len(fields) >= 4 {
-			name := fields[0]
-			models = append(models, OllamaModel{
-				Name:  name,
-				Model: name,
-				Size:  0, // Size parsing from text is complex, skip for now
-			})
-		}
-	}
-
-	return models, nil
-}
-
-// CLIListRunningModels lists currently running models using ollama CLI
-func CLIListRunningModels(ctx context.Context) ([]string, error) {
-	cmd := exec.CommandContext(ctx, "ollama", "ps")
-	output, err := cmd.Output()
-	if err != nil {
-		return nil, fmt.Errorf("failed to list running models via CLI: %w", err)
-	}
-
-	return parseRunningModelsList(string(output))
-}
-
-// parseRunningModelsList parses the text output from 'ollama ps'
-func parseRunningModelsList(output string) ([]string, error) {
-	lines := strings.Split(strings.TrimSpace(output), "\n")
-	if len(lines) < 2 {
-		return []string{}, nil // No running models
-	}
-
-	var runningModels []string
-	// Skip the header line
-	for i := 1; i < len(lines); i++ {
-		line := strings.TrimSpace(lines[i])
-		if line == "" {
-			continue
-		}
-
-		// Parse each line: NAME ID SIZE PROCESSOR UNTIL
-		fields := strings.Fields(line)
-		if len(fields) >= 1 {
-			name := fields[0]
-			if name != "" {
-				runningModels = append(runningModels, name)
-			}
-		}
-	}
-
-	return runningModels, nil
-}
-
-// CLIStopModel stops a specific model using ollama CLI
-func CLIStopModel(ctx context.Context, modelName string) error {
-	cmd := exec.CommandContext(ctx, "ollama", "stop", modelName)
-	if err := cmd.Run(); err != nil {
-		return fmt.Errorf("failed to stop model %s via CLI: %w", modelName, err)
-	}
-	return nil
-}
-
-// CLIStopAllModels stops all running models using ollama CLI
-func CLIStopAllModels(ctx context.Context) error {
-	// First get list of running models
-	runningModels, err := CLIListRunningModels(ctx)
-	if err != nil {
-		return fmt.Errorf("failed to get running models: %w", err)
-	}
-
-	// Stop each model individually
-	for _, modelName := range runningModels {
-		if err := CLIStopModel(ctx, modelName); err != nil {
-			return fmt.Errorf("failed to stop model %s: %w", modelName, err)
-		}
-	}
-
-	return nil
-}
-
-// CLIIsModelRunning checks if a specific model is running using ollama CLI
-func CLIIsModelRunning(ctx context.Context, modelName string) (bool, error) {
-	runningModels, err := CLIListRunningModels(ctx)
-	if err != nil {
-		return false, err
-	}
-
-	for _, running := range runningModels {
-		if running == modelName {
-			return true, nil
-		}
-	}
-
-	return false, nil
-}
-
-// CLIStartModel starts a model using ollama CLI (similar to StartModel but using CLI)
-func CLIStartModel(ctx context.Context, modelName string) error {
-	// Use ollama run with a simple prompt that immediately exits
-	cmd := exec.CommandContext(ctx, "ollama", "run", modelName, "--verbose", "hi")
-
-	// Set a shorter timeout for the run command since we just want to load the model
-	runCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
-	defer cancel()
-
-	cmd = exec.CommandContext(runCtx, "ollama", "run", modelName, "hi")
-
-	if err := cmd.Run(); err != nil {
-		return fmt.Errorf("failed to start model %s via CLI: %w", modelName, err)
-	}
-
-	return nil
-}
-
-// CLIGetModelsCount returns the number of available models using CLI
-func CLIGetModelsCount(ctx context.Context) (int, error) {
-	models, err := CLIListModels(ctx)
-	if err != nil {
-		return 0, err
-	}
-	return len(models), nil
-}
-
-// Performance comparison helpers
-
-// BenchmarkCLIvsHTTP compares CLI vs HTTP performance
-func BenchmarkCLIvsHTTP(ctx context.Context) (map[string]time.Duration, error) {
-	results := make(map[string]time.Duration)
-
-	// Test HTTP approach
-	start := time.Now()
-	_, err := GetModels(ctx)
-	if err != nil {
-		return nil, fmt.Errorf("HTTP GetModels failed: %w", err)
-	}
-	results["HTTP_GetModels"] = time.Since(start)
-
-	// Test CLI approach
-	start = time.Now()
-	_, err = CLIListModels(ctx)
-	if err != nil {
-		return nil, fmt.Errorf("CLI ListModels failed: %w", err)
-	}
-	results["CLI_ListModels"] = time.Since(start)
-
-	// Test HTTP running models
-	start = time.Now()
-	_, err = GetRunningModels(ctx)
-	if err != nil {
-		return nil, fmt.Errorf("HTTP GetRunningModels failed: %w", err)
-	}
-	results["HTTP_GetRunningModels"] = time.Since(start)
-
-	// Test CLI running models
-	start = time.Now()
-	_, err = CLIListRunningModels(ctx)
-	if err != nil {
-		return nil, fmt.Errorf("CLI ListRunningModels failed: %w", err)
-	}
-	results["CLI_ListRunningModels"] = time.Since(start)
-
-	return results, nil
-}
-
-// CLICleanupProcesses provides CLI-based cleanup (alternative to HTTP-based cleanup)
-func CLICleanupProcesses(ctx context.Context) error {
-	return CLIStopAllModels(ctx)
-}
diff --git a/internal/ollama/cli_test.go b/internal/ollama/cli_test.go
deleted file mode 100644
index cac255fa1de3c9a438e3503912ff81e17fe972c1..0000000000000000000000000000000000000000
--- a/internal/ollama/cli_test.go
+++ /dev/null
@@ -1,314 +0,0 @@
-package ollama
-
-import (
-	"context"
-	"testing"
-	"time"
-)
-
-func TestCLIListModels(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping CLI test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
-	models, err := CLIListModels(ctx)
-	if err != nil {
-		t.Fatalf("Failed to list models via CLI: %v", err)
-	}
-
-	t.Logf("Found %d models via CLI", len(models))
-	for _, model := range models {
-		t.Logf("  - %s", model.Name)
-	}
-}
-
-func TestCLIListRunningModels(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping CLI test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
-		t.Log("Starting Ollama service...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
-	}
-
-	runningModels, err := CLIListRunningModels(ctx)
-	if err != nil {
-		t.Fatalf("Failed to list running models via CLI: %v", err)
-	}
-
-	t.Logf("Found %d running models via CLI", len(runningModels))
-	for _, model := range runningModels {
-		t.Logf("  - %s", model)
-	}
-}
-
-func TestCLIStopAllModels(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping CLI test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
-	defer cancel()
-
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
-		t.Log("Starting Ollama service...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
-	}
-
-	// Get available models
-	models, err := GetModels(ctx)
-	if err != nil {
-		t.Fatalf("Failed to get models: %v", err)
-	}
-
-	if len(models) == 0 {
-		t.Skip("No models available, skipping CLI stop test")
-	}
-
-	// Pick a small model for testing
-	testModel := models[0].ID
-	for _, model := range models {
-		if model.ID == "phi3:3.8b" || model.ID == "llama3.2:3b" {
-			testModel = model.ID
-			break
-		}
-	}
-
-	t.Logf("Testing CLI stop with model: %s", testModel)
-
-	// Check if model is running
-	running, err := CLIIsModelRunning(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is running: %v", err)
-	}
-
-	// If not running, start it
-	if !running {
-		t.Log("Starting model for CLI stop test...")
-		if err := StartModel(ctx, testModel); err != nil {
-			t.Fatalf("Failed to start model: %v", err)
-		}
-
-		// Verify it's now running
-		running, err = CLIIsModelRunning(ctx, testModel)
-		if err != nil {
-			t.Fatalf("Failed to check if model is running after start: %v", err)
-		}
-		if !running {
-			t.Fatal("Model failed to start")
-		}
-		t.Log("Model started successfully")
-	} else {
-		t.Log("Model was already running")
-	}
-
-	// Now test CLI stop
-	t.Log("Testing CLI stop all models...")
-	if err := CLIStopAllModels(ctx); err != nil {
-		t.Fatalf("Failed to stop all models via CLI: %v", err)
-	}
-
-	// Give some time for models to stop
-	time.Sleep(2 * time.Second)
-
-	// Check if model is still running
-	running, err = CLIIsModelRunning(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is running after stop: %v", err)
-	}
-
-	if running {
-		t.Error("Model is still running after CLI stop")
-	} else {
-		t.Log("Model successfully stopped via CLI")
-	}
-}
-
-func TestCLIvsHTTPPerformance(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping performance test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
-		t.Log("Starting Ollama service...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
-	}
-
-	results, err := BenchmarkCLIvsHTTP(ctx)
-	if err != nil {
-		t.Fatalf("Failed to benchmark CLI vs HTTP: %v", err)
-	}
-
-	t.Log("Performance Comparison (CLI vs HTTP):")
-	for operation, duration := range results {
-		t.Logf("  %s: %v", operation, duration)
-	}
-
-	// Compare HTTP vs CLI for model listing
-	httpTime := results["HTTP_GetModels"]
-	cliTime := results["CLI_ListModels"]
-
-	if httpTime < cliTime {
-		t.Logf("HTTP is faster for listing models (%v vs %v)", httpTime, cliTime)
-	} else {
-		t.Logf("CLI is faster for listing models (%v vs %v)", cliTime, httpTime)
-	}
-
-	// Compare HTTP vs CLI for running models
-	httpRunningTime := results["HTTP_GetRunningModels"]
-	cliRunningTime := results["CLI_ListRunningModels"]
-
-	if httpRunningTime < cliRunningTime {
-		t.Logf("HTTP is faster for listing running models (%v vs %v)", httpRunningTime, cliRunningTime)
-	} else {
-		t.Logf("CLI is faster for listing running models (%v vs %v)", cliRunningTime, httpRunningTime)
-	}
-}
-
-func TestCLICleanupVsHTTPCleanup(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping cleanup comparison test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
-	defer cancel()
-
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
-		t.Log("Starting Ollama service...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
-	}
-
-	// Get available models
-	models, err := GetModels(ctx)
-	if err != nil {
-		t.Fatalf("Failed to get models: %v", err)
-	}
-
-	if len(models) == 0 {
-		t.Skip("No models available, skipping cleanup comparison test")
-	}
-
-	// Pick a small model for testing
-	testModel := models[0].ID
-	for _, model := range models {
-		if model.ID == "phi3:3.8b" || model.ID == "llama3.2:3b" {
-			testModel = model.ID
-			break
-		}
-	}
-
-	t.Logf("Testing cleanup comparison with model: %s", testModel)
-
-	// Test 1: HTTP-based cleanup
-	t.Log("Testing HTTP-based cleanup...")
-
-	// Start model
-	if err := StartModel(ctx, testModel); err != nil {
-		t.Fatalf("Failed to start model: %v", err)
-	}
-
-	// Verify it's loaded
-	loaded, err := IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded: %v", err)
-	}
-	if !loaded {
-		t.Fatal("Model failed to load")
-	}
-
-	// Time HTTP cleanup
-	start := time.Now()
-	cleanupProcesses()
-	httpCleanupTime := time.Since(start)
-
-	// Give time for cleanup
-	time.Sleep(2 * time.Second)
-
-	// Check if model is still loaded
-	loaded, err = IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded after HTTP cleanup: %v", err)
-	}
-
-	httpCleanupWorked := !loaded
-
-	// Test 2: CLI-based cleanup
-	t.Log("Testing CLI-based cleanup...")
-
-	// Start model again
-	if err := StartModel(ctx, testModel); err != nil {
-		t.Fatalf("Failed to start model for CLI test: %v", err)
-	}
-
-	// Verify it's loaded
-	loaded, err = IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded: %v", err)
-	}
-	if !loaded {
-		t.Fatal("Model failed to load for CLI test")
-	}
-
-	// Time CLI cleanup
-	start = time.Now()
-	if err := CLICleanupProcesses(ctx); err != nil {
-		t.Fatalf("CLI cleanup failed: %v", err)
-	}
-	cliCleanupTime := time.Since(start)
-
-	// Give time for cleanup
-	time.Sleep(2 * time.Second)
-
-	// Check if model is still loaded
-	loaded, err = IsModelLoaded(ctx, testModel)
-	if err != nil {
-		t.Fatalf("Failed to check if model is loaded after CLI cleanup: %v", err)
-	}
-
-	cliCleanupWorked := !loaded
-
-	// Compare results
-	t.Log("Cleanup Comparison Results:")
-	t.Logf("  HTTP cleanup: %v (worked: %v)", httpCleanupTime, httpCleanupWorked)
-	t.Logf("  CLI cleanup: %v (worked: %v)", cliCleanupTime, cliCleanupWorked)
-
-	if httpCleanupWorked && cliCleanupWorked {
-		if httpCleanupTime < cliCleanupTime {
-			t.Logf("HTTP cleanup is faster and both work")
-		} else {
-			t.Logf("CLI cleanup is faster and both work")
-		}
-	} else if httpCleanupWorked && !cliCleanupWorked {
-		t.Logf("HTTP cleanup works better (CLI cleanup failed)")
-	} else if !httpCleanupWorked && cliCleanupWorked {
-		t.Logf("CLI cleanup works better (HTTP cleanup failed)")
-	} else {
-		t.Logf("Both cleanup methods failed")
-	}
-}
diff --git a/internal/ollama/client.go b/internal/ollama/client.go
index d6d8cd220d8034a1472b15c29856586a52dbef5c..5cfdf70a9ab9a0e345e5320111aaf7d943803008 100644
--- a/internal/ollama/client.go
+++ b/internal/ollama/client.go
@@ -1,148 +1,159 @@
 package ollama
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
 	"fmt"
-	"strings"
-
-	"github.com/charmbracelet/crush/internal/fur/provider"
+	"net/http"
 )
 
-// IsRunning checks if Ollama is running by attempting to run a CLI command
+// httpClient creates a configured HTTP client
+func httpClient() *http.Client {
+	return &http.Client{
+		Timeout: DefaultTimeout,
+	}
+}
+
+// IsRunning checks if Ollama service is running
 func IsRunning(ctx context.Context) bool {
-	_, err := CLIListModels(ctx)
-	return err == nil
+	return isRunning(ctx, DefaultBaseURL)
 }
 
-// GetModels retrieves available models from Ollama using CLI
-func GetModels(ctx context.Context) ([]provider.Model, error) {
-	ollamaModels, err := CLIListModels(ctx)
+// isRunning checks if Ollama is running at the specified URL
+func isRunning(ctx context.Context, baseURL string) bool {
+	client := httpClient()
+
+	req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/tags", nil)
 	if err != nil {
-		return nil, err
-	}
-
-	models := make([]provider.Model, len(ollamaModels))
-	for i, ollamaModel := range ollamaModels {
-		family := extractModelFamily(ollamaModel.Name)
-		models[i] = provider.Model{
-			ID:                 ollamaModel.Name,
-			Model:              ollamaModel.Name,
-			CostPer1MIn:        0, // Local models have no cost
-			CostPer1MOut:       0,
-			CostPer1MInCached:  0,
-			CostPer1MOutCached: 0,
-			ContextWindow:      getContextWindow(family),
-			DefaultMaxTokens:   4096,
-			CanReason:          false,
-			HasReasoningEffort: false,
-			SupportsImages:     supportsImages(family),
-		}
+		return false
 	}
 
-	return models, nil
+	resp, err := client.Do(req)
+	if err != nil {
+		return false
+	}
+	defer resp.Body.Close()
+
+	return resp.StatusCode == http.StatusOK
+}
+
+// GetModels retrieves all available models
+func GetModels(ctx context.Context) ([]Model, error) {
+	return getModels(ctx, DefaultBaseURL)
 }
 
-// GetRunningModels returns models that are currently loaded in memory using CLI
-func GetRunningModels(ctx context.Context) ([]OllamaRunningModel, error) {
-	runningModelNames, err := CLIListRunningModels(ctx)
+// getModels retrieves models from the specified URL
+func getModels(ctx context.Context, baseURL string) ([]Model, error) {
+	client := httpClient()
+
+	req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/tags", nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+
+	resp, err := client.Do(req)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
 	}
+	defer resp.Body.Close()
 
-	var runningModels []OllamaRunningModel
-	for _, name := range runningModelNames {
-		runningModels = append(runningModels, OllamaRunningModel{
-			Name: name,
-		})
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
 	}
 
-	return runningModels, nil
+	var response TagsResponse
+	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return response.Models, nil
 }
 
-// IsModelLoaded checks if a specific model is currently loaded in memory using CLI
-func IsModelLoaded(ctx context.Context, modelName string) (bool, error) {
-	return CLIIsModelRunning(ctx, modelName)
+// GetRunningModels retrieves currently running models
+func GetRunningModels(ctx context.Context) ([]RunningModel, error) {
+	return getRunningModels(ctx, DefaultBaseURL)
 }
 
-// GetProvider returns a provider.Provider for Ollama if it's running
-func GetProvider(ctx context.Context) (*provider.Provider, error) {
-	if !IsRunning(ctx) {
-		return nil, fmt.Errorf("Ollama is not running")
+// getRunningModels retrieves running models from the specified URL
+func getRunningModels(ctx context.Context, baseURL string) ([]RunningModel, error) {
+	client := httpClient()
+
+	req, err := http.NewRequestWithContext(ctx, "GET", baseURL+"/api/ps", nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
 	}
 
-	models, err := GetModels(ctx)
+	resp, err := client.Do(req)
 	if err != nil {
-		return nil, fmt.Errorf("failed to get models: %w", err)
+		return nil, fmt.Errorf("failed to connect to Ollama: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
 	}
 
-	return &provider.Provider{
-		Name:   "Ollama",
-		ID:     "ollama",
-		Models: models,
-	}, nil
+	var response ProcessStatusResponse
+	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
+		return nil, fmt.Errorf("failed to decode response: %w", err)
+	}
+
+	return response.Models, nil
 }
 
-// extractModelFamily extracts the model family from a model name
-func extractModelFamily(modelName string) string {
-	// Extract the family from model names like "llama3.2:3b" -> "llama"
-	parts := strings.Split(modelName, ":")
-	if len(parts) > 0 {
-		name := parts[0]
-		// Handle cases like "llama3.2" -> "llama"
-		if strings.HasPrefix(name, "llama") {
-			return "llama"
-		}
-		if strings.HasPrefix(name, "mistral") {
-			return "mistral"
-		}
-		if strings.HasPrefix(name, "gemma") {
-			return "gemma"
-		}
-		if strings.HasPrefix(name, "qwen") {
-			return "qwen"
-		}
-		if strings.HasPrefix(name, "phi") {
-			return "phi"
-		}
-		if strings.HasPrefix(name, "codellama") {
-			return "codellama"
-		}
-		if strings.Contains(name, "llava") {
-			return "llava"
-		}
-		if strings.Contains(name, "vision") {
-			return "llama-vision"
+// IsModelRunning checks if a specific model is currently running
+func IsModelRunning(ctx context.Context, modelName string) (bool, error) {
+	runningModels, err := GetRunningModels(ctx)
+	if err != nil {
+		return false, err
+	}
+
+	for _, model := range runningModels {
+		if model.Name == modelName {
+			return true, nil
 		}
 	}
-	return "unknown"
+
+	return false, nil
 }
 
-// getContextWindow returns an estimated context window based on model family
-func getContextWindow(family string) int64 {
-	switch family {
-	case "llama":
-		return 131072 // Llama 3.x context window
-	case "mistral":
-		return 32768
-	case "gemma":
-		return 8192
-	case "qwen", "qwen2":
-		return 131072
-	case "phi":
-		return 131072
-	case "codellama":
-		return 16384
-	default:
-		return 8192 // Conservative default
-	}
+// LoadModel loads a model into memory by sending a simple request
+func LoadModel(ctx context.Context, modelName string) error {
+	return loadModel(ctx, DefaultBaseURL, modelName)
 }
 
-// supportsImages returns whether a model family supports image inputs
-func supportsImages(family string) bool {
-	switch family {
-	case "llama-vision", "llava":
-		return true
-	default:
-		return false
+// loadModel loads a model at the specified URL
+func loadModel(ctx context.Context, baseURL, modelName string) error {
+	client := httpClient()
+
+	reqBody := GenerateRequest{
+		Model:  modelName,
+		Prompt: "hi",
+		Stream: false,
+	}
+
+	reqData, err := json.Marshal(reqBody)
+	if err != nil {
+		return fmt.Errorf("failed to marshal request: %w", err)
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/api/generate", bytes.NewBuffer(reqData))
+	if err != nil {
+		return fmt.Errorf("failed to create request: %w", err)
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("failed to load model: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("failed to load model, status: %d", resp.StatusCode)
+	}
+
+	return nil
 }
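The rewritten client depends on request/response types and constants (Model, RunningModel, TagsResponse, ProcessStatusResponse, GenerateRequest, DefaultBaseURL, DefaultTimeout) defined elsewhere in the package and not shown in this diff. A plausible sketch, assuming the JSON shapes of Ollama's /api/tags, /api/ps, and /api/generate endpoints; only the fields this diff actually reads are included, and the timeout value is a guess:

package ollama

import "time"

const (
	// DefaultBaseURL is the address a local Ollama instance listens on by default.
	DefaultBaseURL = "http://localhost:11434"
	// DefaultTimeout bounds each HTTP call made by httpClient (assumed value).
	DefaultTimeout = 30 * time.Second
)

// Model is one entry in the /api/tags response.
type Model struct {
	Name string `json:"name"`
	Size int64  `json:"size"`
}

// RunningModel is one entry in the /api/ps response.
type RunningModel struct {
	Name string `json:"name"`
	Size int64  `json:"size"`
}

// TagsResponse is the envelope returned by GET /api/tags.
type TagsResponse struct {
	Models []Model `json:"models"`
}

// ProcessStatusResponse is the envelope returned by GET /api/ps.
type ProcessStatusResponse struct {
	Models []RunningModel `json:"models"`
}

// GenerateRequest is the body POSTed to /api/generate; a non-streaming
// one-word prompt is enough to force Ollama to load the model.
type GenerateRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	Stream bool   `json:"stream"`
}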
diff --git a/internal/ollama/client_test.go b/internal/ollama/client_test.go
index cd690ff2cfec9cea7573fa75ffa7a1e63f49a2bd..a4598afe26affa10d9e21c576728746bb8ba9084 100644
--- a/internal/ollama/client_test.go
+++ b/internal/ollama/client_test.go
@@ -17,13 +17,10 @@ func TestIsRunning(t *testing.T) {
 	running := IsRunning(ctx)
 
 	if running {
-		t.Log("Ollama is running")
+		t.Log("✓ Ollama is running")
 	} else {
-		t.Log("Ollama is not running")
+		t.Log("✗ Ollama is not running")
 	}
-
-	// This test doesn't fail - it's informational
-	// The behavior depends on whether Ollama is actually running
 }
 
 func TestGetModels(t *testing.T) {
@@ -34,13 +31,8 @@ func TestGetModels(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	// Ensure Ollama is running
 	if !IsRunning(ctx) {
-		t.Log("Ollama is not running, attempting to start...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
+		t.Skip("Ollama is not running, skipping GetModels test")
 	}
 
 	models, err := GetModels(ctx)
@@ -48,10 +40,9 @@ func TestGetModels(t *testing.T) {
 		t.Fatalf("Failed to get models: %v", err)
 	}
 
-	t.Logf("Found %d models:", len(models))
+	t.Logf("✓ Found %d models", len(models))
 	for _, model := range models {
-		t.Logf("  - %s (context: %d, max_tokens: %d)",
-			model.ID, model.ContextWindow, model.DefaultMaxTokens)
+		t.Logf("  - %s (size: %d bytes)", model.Name, model.Size)
 	}
 }
 
@@ -63,13 +54,8 @@ func TestGetRunningModels(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	// Ensure Ollama is running
 	if !IsRunning(ctx) {
-		t.Log("Ollama is not running, attempting to start...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
+		t.Skip("Ollama is not running, skipping GetRunningModels test")
 	}
 
 	runningModels, err := GetRunningModels(ctx)
@@ -77,138 +63,46 @@ func TestGetRunningModels(t *testing.T) {
 		t.Fatalf("Failed to get running models: %v", err)
 	}
 
-	t.Logf("Found %d running models:", len(runningModels))
+	t.Logf("✓ Found %d running models", len(runningModels))
 	for _, model := range runningModels {
-		t.Logf("  - %s", model.Name)
+		t.Logf("  - %s (size: %d bytes)", model.Name, model.Size)
 	}
 }
 
-func TestIsModelLoaded(t *testing.T) {
+func TestIsModelRunning(t *testing.T) {
 	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping IsModelLoaded test")
+		t.Skip("Ollama is not installed, skipping IsModelRunning test")
 	}
 
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 
-	// Ensure Ollama is running
 	if !IsRunning(ctx) {
-		t.Log("Ollama is not running, attempting to start...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
+		t.Skip("Ollama is not running, skipping IsModelRunning test")
 	}
 
-	// Get available models first
 	models, err := GetModels(ctx)
 	if err != nil {
 		t.Fatalf("Failed to get models: %v", err)
 	}
 
 	if len(models) == 0 {
-		t.Skip("No models available, skipping IsModelLoaded test")
+		t.Skip("No models available, skipping IsModelRunning test")
 	}
 
-	testModel := models[0].ID
-	t.Logf("Testing model: %s", testModel)
-
-	loaded, err := IsModelLoaded(ctx, testModel)
+	testModel := models[0].Name
+	running, err := IsModelRunning(ctx, testModel)
 	if err != nil {
-		t.Fatalf("Failed to check if model is loaded: %v", err)
+		t.Fatalf("Failed to check if model is running: %v", err)
 	}
 
-	if loaded {
-		t.Logf("Model %s is loaded", testModel)
+	if running {
+		t.Logf("✓ Model %s is running", testModel)
 	} else {
-		t.Logf("Model %s is not loaded", testModel)
-	}
-}
-
-func TestGetProvider(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping GetProvider test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-	defer cancel()
-
-	// Ensure Ollama is running
-	if !IsRunning(ctx) {
-		t.Log("Ollama is not running, attempting to start...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-		defer cleanupProcesses()
-	}
-
-	provider, err := GetProvider(ctx)
-	if err != nil {
-		t.Fatalf("Failed to get provider: %v", err)
-	}
-
-	if provider.Name != "Ollama" {
-		t.Errorf("Expected provider name to be 'Ollama', got '%s'", provider.Name)
-	}
-
-	if provider.ID != "ollama" {
-		t.Errorf("Expected provider ID to be 'ollama', got '%s'", provider.ID)
-	}
-
-	t.Logf("Provider: %s (ID: %s) with %d models",
-		provider.Name, provider.ID, len(provider.Models))
-}
-
-func TestGetContextWindow(t *testing.T) {
-	tests := []struct {
-		family   string
-		expected int64
-	}{
-		{"llama", 131072},
-		{"mistral", 32768},
-		{"gemma", 8192},
-		{"qwen", 131072},
-		{"qwen2", 131072},
-		{"phi", 131072},
-		{"codellama", 16384},
-		{"unknown", 8192},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.family, func(t *testing.T) {
-			result := getContextWindow(tt.family)
-			if result != tt.expected {
-				t.Errorf("getContextWindow(%s) = %d, expected %d",
-					tt.family, result, tt.expected)
-			}
-		})
-	}
-}
-
-func TestSupportsImages(t *testing.T) {
-	tests := []struct {
-		family   string
-		expected bool
-	}{
-		{"llama-vision", true},
-		{"llava", true},
-		{"llama", false},
-		{"mistral", false},
-		{"unknown", false},
-	}
-
-	for _, tt := range tests {
-		t.Run(tt.family, func(t *testing.T) {
-			result := supportsImages(tt.family)
-			if result != tt.expected {
-				t.Errorf("supportsImages(%s) = %v, expected %v",
-					tt.family, result, tt.expected)
-			}
-		})
+		t.Logf("✗ Model %s is not running", testModel)
 	}
 }
 
-// Benchmark tests for client functions
 func BenchmarkIsRunning(b *testing.B) {
 	if !IsInstalled() {
 		b.Skip("Ollama is not installed")
@@ -228,7 +122,6 @@ func BenchmarkGetModels(b *testing.B) {
 
 	ctx := context.Background()
 
-	// Ensure Ollama is running for benchmark
 	if !IsRunning(ctx) {
 		b.Skip("Ollama is not running")
 	}
diff --git a/internal/ollama/ollama.go b/internal/ollama/install.go
similarity index 100%
rename from internal/ollama/ollama.go
rename to internal/ollama/install.go
diff --git a/internal/ollama/install_test.go b/internal/ollama/install_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..77dfec215d3a01af2fbd9eaf46af3ae9a0adedf1
--- /dev/null
+++ b/internal/ollama/install_test.go
@@ -0,0 +1,23 @@
+package ollama
+
+import (
+	"testing"
+)
+
+func TestIsInstalled(t *testing.T) {
+	installed := IsInstalled()
+
+	if installed {
+		t.Log("✓ Ollama is installed on this system")
+	} else {
+		t.Log("✗ Ollama is not installed on this system")
+	}
+
+	// This is informational - doesn't fail
+}
+
+func BenchmarkIsInstalled(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		IsInstalled()
+	}
+}
diff --git a/internal/ollama/ollama_test.go b/internal/ollama/ollama_test.go
deleted file mode 100644
index 2832aeb4527e2e924f98098ffb5fad6343d271ed..0000000000000000000000000000000000000000
--- a/internal/ollama/ollama_test.go
+++ /dev/null
@@ -1,25 +0,0 @@
-package ollama
-
-import (
-	"testing"
-)
-
-func TestIsInstalled(t *testing.T) {
-	installed := IsInstalled()
-
-	if installed {
-		t.Log("Ollama is installed on this system")
-	} else {
-		t.Log("Ollama is not installed on this system")
-	}
-
-	// This test doesn't fail - it's informational
-	// In a real scenario, you might want to skip other tests if Ollama is not installed
-}
-
-// Benchmark test for IsInstalled
-func BenchmarkIsInstalled(b *testing.B) {
-	for i := 0; i < b.N; i++ {
-		IsInstalled()
-	}
-}
diff --git a/internal/ollama/process.go b/internal/ollama/process.go
deleted file mode 100644
index 42e9fc6ecec5a433b8ff0e8bf9620c91e467f1bf..0000000000000000000000000000000000000000
--- a/internal/ollama/process.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package ollama
-
-import (
-	"context"
-	"os"
-	"os/exec"
-	"os/signal"
-	"syscall"
-	"time"
-)
-
-var processManager = &ProcessManager{
-	processes: make(map[string]*exec.Cmd),
-}
-
-// setupProcessCleanup sets up signal handlers to clean up processes on exit
-func setupProcessCleanup() {
-	c := make(chan os.Signal, 1)
-	signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-
-	go func() {
-		<-c
-		cleanupProcesses()
-		os.Exit(0)
-	}()
-}
-
-// cleanupProcesses terminates all Ollama processes started by Crush
-func cleanupProcesses() {
-	processManager.mu.Lock()
-	defer processManager.mu.Unlock()
-
-	// Use CLI approach to stop all running models
-	// This is more reliable than tracking individual processes
-	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-	defer cancel()
-
-	if err := CLIStopAllModels(ctx); err != nil {
-		// If CLI approach fails, fall back to process tracking
-		// Clean up model processes
-		for modelName, cmd := range processManager.processes {
-			if cmd.Process != nil {
-				cmd.Process.Kill()
-				cmd.Wait() // Wait for the process to actually exit
-			}
-			delete(processManager.processes, modelName)
-		}
-	} else {
-		// CLI approach succeeded, clear our process tracking
-		processManager.processes = make(map[string]*exec.Cmd)
-	}
-
-	// Clean up Ollama server if Crush started it
-	if processManager.crushStartedOllama && processManager.ollamaServer != nil {
-		if processManager.ollamaServer.Process != nil {
-			// Kill the entire process group to ensure all children are terminated
-			syscall.Kill(-processManager.ollamaServer.Process.Pid, syscall.SIGTERM)
-
-			// Give it a moment to shut down gracefully
-			time.Sleep(2 * time.Second)
-
-			// Force kill if still running
-			if processManager.ollamaServer.ProcessState == nil {
-				syscall.Kill(-processManager.ollamaServer.Process.Pid, syscall.SIGKILL)
-			}
-
-			processManager.ollamaServer.Wait() // Wait for the process to actually exit
-		}
-		processManager.ollamaServer = nil
-		processManager.crushStartedOllama = false
-	}
-}
diff --git a/internal/ollama/process_test.go b/internal/ollama/process_test.go
deleted file mode 100644
index 52f0e594cec4f928f5ddb73edd7ac475001d514f..0000000000000000000000000000000000000000
--- a/internal/ollama/process_test.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package ollama
-
-import (
-	"context"
-	"testing"
-	"time"
-)
-
-func TestProcessManager(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping ProcessManager test")
-	}
-
-	// Test that processManager is initialized
-	if processManager == nil {
-		t.Fatal("processManager is nil")
-	}
-
-	if processManager.processes == nil {
-		t.Fatal("processManager.processes is nil")
-	}
-
-	t.Log("ProcessManager is properly initialized")
-}
-
-func TestCleanupProcesses(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping cleanup test")
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
-	defer cancel()
-
-	// Start Ollama service if not running
-	wasRunning := IsRunning(ctx)
-	if !wasRunning {
-		t.Log("Starting Ollama service for cleanup test...")
-		if err := StartOllamaService(ctx); err != nil {
-			t.Fatalf("Failed to start Ollama service: %v", err)
-		}
-
-		// Verify it started
-		if !IsRunning(ctx) {
-			t.Fatal("Failed to start Ollama service")
-		}
-
-		// Test cleanup
-		t.Log("Testing cleanup...")
-		cleanupProcesses()
-
-		// Give some time for cleanup
-		time.Sleep(3 * time.Second)
-
-		// Verify cleanup worked (service should be stopped)
-		if IsRunning(ctx) {
-			t.Error("Ollama service is still running after cleanup")
-		} else {
-			t.Log("Cleanup successfully stopped Ollama service")
-		}
-	} else {
-		t.Log("Ollama was already running, skipping cleanup test to avoid disruption")
-	}
-}
-
-func TestSetupProcessCleanup(t *testing.T) {
-	// Test that setupProcessCleanup can be called without panicking
-	// Note: We can't easily test signal handling in unit tests
-	defer func() {
-		if r := recover(); r != nil {
-			t.Fatalf("setupProcessCleanup panicked: %v", r)
-		}
-	}()
-
-	// This should not panic and should be safe to call multiple times
-	setupProcessCleanup()
-	setupProcessCleanup() // Should be safe due to sync.Once
-
-	t.Log("setupProcessCleanup completed without panic")
-}
-
-func TestProcessManagerThreadSafety(t *testing.T) {
-	if !IsInstalled() {
-		t.Skip("Ollama is not installed, skipping thread safety test")
-	}
-
-	// Test concurrent access to processManager
-	done := make(chan bool)
-
-	// Start multiple goroutines that access processManager
-	for i := 0; i < 10; i++ {
-		go func() {
-			processManager.mu.RLock()
-			_ = len(processManager.processes)
-			processManager.mu.RUnlock()
-			done <- true
-		}()
-	}
-
-	// Wait for all goroutines to complete
-	for i := 0; i < 10; i++ {
-		select {
-		case <-done:
-			// Success
-		case <-time.After(1 * time.Second):
-			t.Fatal("Thread safety test timed out")
-		}
-	}
-
-	t.Log("ProcessManager thread safety test passed")
-}
diff --git a/internal/ollama/provider.go b/internal/ollama/provider.go
new file mode 100644
index 0000000000000000000000000000000000000000..99a330dd26227e0a5e0a5a36b294f974dfb22698
--- /dev/null
+++ b/internal/ollama/provider.go
@@ -0,0 +1,127 @@
+package ollama
+
+import (
+	"context"
+	"fmt"
+	"strings"
+)
+
+// ProviderModel represents a model in the provider format
+type ProviderModel struct {
+	ID                 string
+	Model              string
+	CostPer1MIn        float64
+	CostPer1MOut       float64
+	CostPer1MInCached  float64
+	CostPer1MOutCached float64
+	ContextWindow      int64
+	DefaultMaxTokens   int64
+	CanReason          bool
+	HasReasoningEffort bool
+	SupportsImages     bool
+}
+
+// Provider represents an Ollama provider
+type Provider struct {
+	Name   string
+	ID     string
+	Models []ProviderModel
+}
+
+// GetProvider returns a Provider for Ollama
+func GetProvider(ctx context.Context) (*Provider, error) {
+	if err := EnsureRunning(ctx); err != nil {
+		return nil, fmt.Errorf("failed to ensure Ollama is running: %w", err)
+	}
+
+	models, err := GetModels(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get models: %w", err)
+	}
+
+	providerModels := make([]ProviderModel, len(models))
+	for i, model := range models {
+		family := extractModelFamily(model.Name)
+		providerModels[i] = ProviderModel{
+			ID:                 model.Name,
+			Model:              model.Name,
+			CostPer1MIn:        0, // Local models have no cost
+			CostPer1MOut:       0,
+			CostPer1MInCached:  0,
+			CostPer1MOutCached: 0,
+			ContextWindow:      getContextWindow(family),
+			DefaultMaxTokens:   4096,
+			CanReason:          false,
+			HasReasoningEffort: false,
+			SupportsImages:     supportsImages(family),
+		}
+	}
+
+	return &Provider{
+		Name:   "Ollama",
+		ID:     "ollama",
+		Models: providerModels,
+	}, nil
+}
+
+// extractModelFamily extracts the model family from a model name
+func extractModelFamily(modelName string) string {
+	// Extract the family from model names like "llama3.2:3b" -> "llama"
+	parts := strings.Split(modelName, ":")
+	if len(parts) > 0 {
+		name := strings.ToLower(parts[0])
+
+		// Handle various model families in specific order
+		switch {
+		case strings.Contains(name, "llama-vision"):
+			return "llama-vision"
+		case strings.Contains(name, "codellama"):
+			return "codellama"
+		case strings.Contains(name, "llava"):
+			return "llava"
+		case strings.Contains(name, "llama"):
+			return "llama"
+		case strings.Contains(name, "mistral"):
+			return "mistral"
+		case strings.Contains(name, "gemma"):
+			return "gemma"
+		case strings.Contains(name, "qwen"):
+			return "qwen"
+		case strings.Contains(name, "phi"):
+			return "phi"
+		case strings.Contains(name, "vision"):
+			return "llama-vision"
+		}
+	}
+	return "unknown"
+}
+
+// getContextWindow returns an estimated context window based on model family
+func getContextWindow(family string) int64 {
+	switch family {
+	case "llama":
+		return 131072 // Llama 3.x context window
+	case "mistral":
+		return 32768
+	case "gemma":
+		return 8192
+	case "qwen":
+		return 131072
+	case "phi":
+		return 131072
+	case "codellama":
+		return 16384
+	default:
+		return 8192 // Conservative default
+	}
+}
+
+// supportsImages returns whether a model family supports image inputs
+func supportsImages(family string) bool {
+	switch family {
+	case "llama-vision", "llava":
+		return true
+	default:
+		return false
+	}
+}
diff --git a/internal/ollama/provider_test.go b/internal/ollama/provider_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..06f6fbd83d469b0ffbf0bc5b77293214d6c4f9c3
--- /dev/null
+++ b/internal/ollama/provider_test.go
@@ -0,0 +1,120 @@
+package ollama
+
+import (
+	"context"
+	"testing"
+	"time"
+)
+
+func TestGetProvider(t *testing.T) {
+	if !IsInstalled() {
+		t.Skip("Ollama is not installed, skipping GetProvider test")
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	provider, err := GetProvider(ctx)
+	if err != nil {
+		t.Fatalf("Failed to get provider: %v", err)
+	}
+
+	if provider.Name != "Ollama" {
+		t.Errorf("Expected provider name to be 'Ollama', got '%s'", provider.Name)
+	}
+
+	if provider.ID != "ollama" {
+		t.Errorf("Expected provider ID to be 'ollama', got '%s'", provider.ID)
+	}
+
+	t.Logf("✓ Provider: %s (ID: %s) with %d models",
+		provider.Name, provider.ID, len(provider.Models))
+
+	// Test model details
+	for _, model := range provider.Models {
+		t.Logf("  - %s (context: %d, max_tokens: %d, images: %v)",
+			model.ID, model.ContextWindow, model.DefaultMaxTokens, model.SupportsImages)
+	}
+
+	// Cleanup
+	defer func() {
+		if processManager.crushStartedOllama {
+			cleanup()
+		}
+	}()
+}
+
+func TestExtractModelFamily(t *testing.T) {
+	tests := []struct {
+		modelName string
+		expected  string
+	}{
+		{"llama3.2:3b", "llama"},
+		{"mistral:7b", "mistral"},
+		{"gemma:2b", "gemma"},
+		{"qwen2.5:14b", "qwen"},
+		{"phi3:3.8b", "phi"},
+		{"codellama:13b", "codellama"},
+		{"llava:13b", "llava"},
+		{"llama-vision:7b", "llama-vision"},
+		{"unknown-model:1b", "unknown"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.modelName, func(t *testing.T) {
+			result := extractModelFamily(tt.modelName)
+			if result != tt.expected {
+				t.Errorf("extractModelFamily(%s) = %s, expected %s",
+					tt.modelName, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestGetContextWindow(t *testing.T) {
+	tests := []struct {
+		family   string
+		expected int64
+	}{
+		{"llama", 131072},
+		{"mistral", 32768},
+		{"gemma", 8192},
+		{"qwen", 131072},
+		{"phi", 131072},
+		{"codellama", 16384},
+		{"unknown", 8192},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.family, func(t *testing.T) {
+			result := getContextWindow(tt.family)
+			if result != tt.expected {
+				t.Errorf("getContextWindow(%s) = %d, expected %d",
+					tt.family, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestSupportsImages(t *testing.T) {
+	tests := []struct {
+		family   string
+		expected bool
+	}{
+		{"llama-vision", true},
+		{"llava", true},
+		{"llama", false},
+		{"mistral", false},
+		{"unknown", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.family, func(t *testing.T) {
+			result := supportsImages(tt.family)
+			if result != tt.expected {
+				t.Errorf("supportsImages(%s) = %v, expected %v",
+					tt.family, result, tt.expected)
+			}
+		})
+	}
+}
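The service code that follows also references two timeout constants, ServiceStartTimeout and ModelLoadTimeout, that never appear in this diff. Hypothetical definitions, sized to fit inside the 20s and 90s test deadlines used in service_test.go; the exact values are assumptions:

package ollama

import "time"

const (
	// ServiceStartTimeout caps how long StartService polls IsRunning (assumed value).
	ServiceStartTimeout = 10 * time.Second
	// ModelLoadTimeout caps the blocking /api/generate call in EnsureModelLoaded (assumed value).
	ModelLoadTimeout = 60 * time.Second
)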
diff --git a/internal/ollama/service.go b/internal/ollama/service.go
index a603d2e80974c58ec83fa6ace7c532710fb43cac..059bec71009078c6ed43271d7d74fe80fe2e3611 100644
--- a/internal/ollama/service.go
+++ b/internal/ollama/service.go
@@ -8,110 +8,80 @@ import (
 	"time"
 )
 
-// StartOllamaService starts the Ollama service if it's not already running
-func StartOllamaService(ctx context.Context) error {
+var processManager = &ProcessManager{}
+
+// StartService starts the Ollama service if not already running
+func StartService(ctx context.Context) error {
 	if IsRunning(ctx) {
 		return nil // Already running
 	}
 
-	// Set up signal handling for cleanup
+	if !IsInstalled() {
+		return fmt.Errorf("Ollama is not installed")
+	}
+
+	processManager.mu.Lock()
+	defer processManager.mu.Unlock()
+
+	// Set up cleanup on first use
 	processManager.setupOnce.Do(func() {
-		setupProcessCleanup()
+		setupCleanup()
 	})
 
-	// Start ollama serve
+	// Start Ollama service
 	cmd := exec.CommandContext(ctx, "ollama", "serve")
-	cmd.Stdout = nil // Suppress output
-	cmd.Stderr = nil // Suppress errors
-	cmd.SysProcAttr = &syscall.SysProcAttr{
-		Setpgid: true, // Create new process group so we can kill it and all children
-	}
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
 
 	if err := cmd.Start(); err != nil {
 		return fmt.Errorf("failed to start Ollama service: %w", err)
 	}
 
-	// Store the process for cleanup
-	processManager.mu.Lock()
-	processManager.ollamaServer = cmd
+	processManager.ollamaProcess = cmd
 	processManager.crushStartedOllama = true
-	processManager.mu.Unlock()
-
-	// Wait for Ollama to be ready (with timeout)
-	timeout := time.After(10 * time.Second)
-	ticker := time.NewTicker(500 * time.Millisecond)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-timeout:
-			return fmt.Errorf("timeout waiting for Ollama service to start")
-		case <-ticker.C:
-			if IsRunning(ctx) {
-				return nil // Ollama is now running
-			}
-		case <-ctx.Done():
-			return ctx.Err()
+
+	// Wait for service to be ready
+	startTime := time.Now()
+	for time.Since(startTime) < ServiceStartTimeout {
+		if IsRunning(ctx) {
+			return nil
 		}
+		time.Sleep(100 * time.Millisecond)
 	}
-}
 
-// StartModel starts a model using `ollama run` and keeps it loaded
-func StartModel(ctx context.Context, modelName string) error {
-	// Check if model is already running
-	if loaded, err := IsModelLoaded(ctx, modelName); err != nil {
-		return fmt.Errorf("failed to check if model is loaded: %w", err)
-	} else if loaded {
-		return nil // Model is already running
-	}
+	return fmt.Errorf("Ollama service did not start within %v", ServiceStartTimeout)
+}
 
-	// Set up signal handling for cleanup
+// EnsureRunning ensures Ollama service is running, starting it if necessary
+func EnsureRunning(ctx context.Context) error {
+	// Always ensure cleanup is set up, even if Ollama was already running
 	processManager.setupOnce.Do(func() {
-		setupProcessCleanup()
+		setupCleanup()
 	})
+	return StartService(ctx)
+}
 
-	// Start the model in the background
-	cmd := exec.CommandContext(ctx, "ollama", "run", modelName)
-	cmd.Stdin = nil  // No interactive input
-	cmd.Stdout = nil // Suppress output
-	cmd.Stderr = nil // Suppress errors
+// EnsureModelLoaded ensures a model is loaded, loading it if necessary
+func EnsureModelLoaded(ctx context.Context, modelName string) error {
+	if err := EnsureRunning(ctx); err != nil {
+		return err
+	}
 
-	if err := cmd.Start(); err != nil {
-		return fmt.Errorf("failed to start model %s: %w", modelName, err)
+	running, err := IsModelRunning(ctx, modelName)
+	if err != nil {
+		return fmt.Errorf("failed to check if model is running: %w", err)
 	}
 
-	// Store the process for cleanup
-	processManager.mu.Lock()
-	processManager.processes[modelName] = cmd
-	processManager.mu.Unlock()
-
-	// Wait for the model to be loaded (with timeout)
-	timeout := time.After(30 * time.Second)
-	ticker := time.NewTicker(1 * time.Second)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-timeout:
-			return fmt.Errorf("timeout waiting for model %s to load", modelName)
-		case <-ticker.C:
-			if loaded, err := IsModelLoaded(ctx, modelName); err != nil {
-				return fmt.Errorf("failed to check if model is loaded: %w", err)
-			} else if loaded {
-				return nil // Model is now running
-			}
-		case <-ctx.Done():
-			return ctx.Err()
-		}
+	if running {
+		return nil // Already loaded
 	}
-}
 
-// EnsureOllamaRunning ensures Ollama service is running, starting it if necessary
-func EnsureOllamaRunning(ctx context.Context) error {
-	return StartOllamaService(ctx)
-}
+	// Load the model
+	loadCtx, cancel := context.WithTimeout(ctx, ModelLoadTimeout)
+	defer cancel()
+
+	if err := LoadModel(loadCtx, modelName); err != nil {
+		return fmt.Errorf("failed to load model %s: %w", modelName, err)
+	}
 
-// EnsureModelRunning ensures a model is running, starting it if necessary
-func EnsureModelRunning(ctx context.Context, modelName string) error {
-	return StartModel(ctx, modelName)
+	return nil
 }
idempotent") + + // Cleanup + defer func() { + if processManager.crushStartedOllama { + cleanup() + } + }() } -func TestStartModel(t *testing.T) { +func TestEnsureModelLoaded(t *testing.T) { if !IsInstalled() { - t.Skip("Ollama is not installed, skipping StartModel test") + t.Skip("Ollama is not installed, skipping EnsureModelLoaded test") } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) defer cancel() - // Ensure Ollama is running - if !IsRunning(ctx) { - t.Log("Starting Ollama service...") - if err := StartOllamaService(ctx); err != nil { - t.Fatalf("Failed to start Ollama service: %v", err) - } - defer cleanupProcesses() + // Get available models + if err := EnsureRunning(ctx); err != nil { + t.Fatalf("Failed to ensure Ollama is running: %v", err) } - // Get available models models, err := GetModels(ctx) if err != nil { t.Fatalf("Failed to get models: %v", err) } if len(models) == 0 { - t.Skip("No models available, skipping StartModel test") + t.Skip("No models available, skipping EnsureModelLoaded test") } - // Pick a smaller model if available, otherwise use the first one - testModel := models[0].ID + // Pick a smaller model if available + testModel := models[0].Name for _, model := range models { - if model.ID == "phi3:3.8b" || model.ID == "llama3.2:3b" { - testModel = model.ID + if model.Name == "phi3:3.8b" || model.Name == "llama3.2:3b" { + testModel = model.Name break } } t.Logf("Testing with model: %s", testModel) - // Check if model is already loaded - loaded, err := IsModelLoaded(ctx, testModel) - if err != nil { - t.Fatalf("Failed to check if model is loaded: %v", err) - } - - if loaded { - t.Log("Model is already loaded, skipping start test") - return - } - - t.Log("Starting model...") - err = StartModel(ctx, testModel) - if err != nil { - t.Fatalf("Failed to start model: %v", err) - } - - // Verify model is now loaded - loaded, err = IsModelLoaded(ctx, testModel) + err = EnsureModelLoaded(ctx, testModel) if err != nil { - t.Fatalf("Failed to check if model is loaded after start: %v", err) - } - - if !loaded { - t.Fatal("StartModel succeeded but model is not loaded") - } - - t.Log("Model started successfully") -} - -func TestEnsureModelRunning(t *testing.T) { - if !IsInstalled() { - t.Skip("Ollama is not installed, skipping EnsureModelRunning test") + t.Fatalf("Failed to ensure model is loaded: %v", err) } - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - // Ensure Ollama is running - if !IsRunning(ctx) { - t.Log("Starting Ollama service...") - if err := StartOllamaService(ctx); err != nil { - t.Fatalf("Failed to start Ollama service: %v", err) - } - defer cleanupProcesses() - } - - // Get available models - models, err := GetModels(ctx) + // Verify model is loaded + running, err := IsModelRunning(ctx, testModel) if err != nil { - t.Fatalf("Failed to get models: %v", err) + t.Fatalf("Failed to check if model is running: %v", err) } - if len(models) == 0 { - t.Skip("No models available, skipping EnsureModelRunning test") + if !running { + t.Fatal("EnsureModelLoaded succeeded but model is not running") } - testModel := models[0].ID - t.Logf("Testing with model: %s", testModel) - - // Test EnsureModelRunning - err = EnsureModelRunning(ctx, testModel) - if err != nil { - t.Fatalf("EnsureModelRunning failed: %v", err) - } + t.Log("✓ EnsureModelLoaded succeeded") - // Verify model is running - loaded, err := IsModelLoaded(ctx, 
testModel) + // Test calling it again when already loaded + err = EnsureModelLoaded(ctx, testModel) if err != nil { - t.Fatalf("Failed to check if model is loaded: %v", err) - } - - if !loaded { - t.Fatal("EnsureModelRunning succeeded but model is not loaded") + t.Fatalf("EnsureModelLoaded failed on second call: %v", err) } - t.Log("EnsureModelRunning succeeded") + t.Log("✓ EnsureModelLoaded is idempotent") - // Test calling it again when already running - err = EnsureModelRunning(ctx, testModel) - if err != nil { - t.Fatalf("EnsureModelRunning failed on second call: %v", err) - } - - t.Log("EnsureModelRunning works when model already running") + // Cleanup + defer func() { + if processManager.crushStartedOllama { + cleanup() + } + }() } diff --git a/internal/ollama/types.go b/internal/ollama/types.go index 3f1815d6ba1f1607af44b02ec1bd1a532e43d35b..82c8417841ae7b5f4be651104980e29175dd7347 100644 --- a/internal/ollama/types.go +++ b/internal/ollama/types.go @@ -3,25 +3,73 @@ package ollama import ( "os/exec" "sync" + "time" ) -// OllamaModel represents a model parsed from Ollama CLI output -type OllamaModel struct { - Name string - Model string - Size int64 +// Constants for configuration +const ( + DefaultBaseURL = "http://localhost:11434" + DefaultTimeout = 30 * time.Second + ServiceStartTimeout = 15 * time.Second + ModelLoadTimeout = 60 * time.Second +) + +// Model represents an Ollama model +type Model struct { + Name string `json:"name"` + Model string `json:"model"` + Size int64 `json:"size"` + Digest string `json:"digest"` + ModifiedAt time.Time `json:"modified_at"` + Details struct { + ParentModel string `json:"parent_model"` + Format string `json:"format"` + Family string `json:"family"` + Families []string `json:"families"` + ParameterSize string `json:"parameter_size"` + QuantizationLevel string `json:"quantization_level"` + } `json:"details"` +} + +// RunningModel represents a model currently loaded in memory +type RunningModel struct { + Name string `json:"name"` + Model string `json:"model"` + Size int64 `json:"size"` + Digest string `json:"digest"` + ExpiresAt time.Time `json:"expires_at"` + SizeVRAM int64 `json:"size_vram"` + Details struct { + ParentModel string `json:"parent_model"` + Format string `json:"format"` + Family string `json:"family"` + Families []string `json:"families"` + ParameterSize string `json:"parameter_size"` + QuantizationLevel string `json:"quantization_level"` + } `json:"details"` +} + +// TagsResponse represents the response from /api/tags +type TagsResponse struct { + Models []Model `json:"models"` +} + +// ProcessStatusResponse represents the response from /api/ps +type ProcessStatusResponse struct { + Models []RunningModel `json:"models"` } -// OllamaRunningModel represents a model that is currently loaded in memory -type OllamaRunningModel struct { - Name string +// GenerateRequest represents a request to /api/generate +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream"` } // ProcessManager manages Ollama processes started by Crush type ProcessManager struct { mu sync.RWMutex - processes map[string]*exec.Cmd - ollamaServer *exec.Cmd // The main Ollama server process + ollamaProcess *exec.Cmd + crushStartedOllama bool setupOnce sync.Once - crushStartedOllama bool // Track if Crush started the Ollama service } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 8fe4d125557fccc8fab7436df1830c6f6498b7d5..8df513e41a813e83910f4a13a0f5e324ef6bcd6b 100644 --- 
a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -179,7 +179,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // If this is an Ollama model, ensure it's running if msg.Model.Provider == "ollama" { - if err := ollama.EnsureModelRunning(context.Background(), msg.Model.Model); err != nil { - return a, util.ReportError(fmt.Errorf("failed to start Ollama model %s: %v", msg.Model.Model, err)) + if err := ollama.EnsureModelLoaded(context.Background(), msg.Model.Model); err != nil { + return a, util.ReportError(fmt.Errorf("failed to load Ollama model %s: %w", msg.Model.Model, err)) } }
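For reviewers, a quick end-to-end view of how the reworked surface composes. This is a minimal caller-side sketch, not part of the patch: the import path and the use of a provider model ID as the loadable model name are assumptions, while IsInstalled, EnsureRunning, GetProvider, EnsureModelLoaded, and ModelLoadTimeout all come from the code in this diff.

package main

import (
	"context"
	"fmt"
	"log"

	// Assumed import path; internal packages are only importable from
	// inside the Crush module, so this sketch would live in the repo.
	"github.com/charmbracelet/crush/internal/ollama"
)

func main() {
	// ModelLoadTimeout is the 60s constant introduced in types.go.
	ctx, cancel := context.WithTimeout(context.Background(), ollama.ModelLoadTimeout)
	defer cancel()

	if !ollama.IsInstalled() {
		log.Fatal("ollama binary not found on PATH")
	}

	// Starts `ollama serve` only if it is not already running; the
	// setupOnce hook registers SIGINT/SIGTERM handlers so anything
	// Crush started is stopped again on exit.
	if err := ollama.EnsureRunning(ctx); err != nil {
		log.Fatalf("ollama not available: %v", err)
	}

	p, err := ollama.GetProvider(ctx)
	if err != nil {
		log.Fatalf("detect provider: %v", err)
	}
	for _, m := range p.Models {
		fmt.Printf("%s (context window: %d)\n", m.ID, m.ContextWindow)
	}

	// EnsureModelLoaded is idempotent: if the model is already in
	// memory, it returns immediately without reloading.
	if len(p.Models) > 0 {
		// Assumes the provider model ID matches the Ollama model name.
		if err := ollama.EnsureModelLoaded(ctx, p.Models[0].ID); err != nil {
			log.Fatalf("load model: %v", err)
		}
	}
}

The idempotency of EnsureRunning and EnsureModelLoaded is what lets the TUI hook above call them unconditionally on every model switch rather than tracking load state itself.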