diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 3f31335dabb4802b5811947be8c5c0b9b8a03fc4..23fa331a0849940460b432d3c34d53fa9194e923 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -19,6 +19,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/message" ) @@ -77,6 +78,12 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a } else if hasBearerAuth { slog.Debug("Skipping X-Api-Key header because Authorization header is provided") } + + if config.Get().Options.Debug { + httpClient := log.NewHTTPClient() + anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(httpClient)) + } + switch tp { case AnthropicClientTypeBedrock: anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) @@ -271,17 +278,11 @@ func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, to } func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { - cfg := config.Get() - attempts := 0 for { attempts++ // Prepare messages on each attempt in case max_tokens was adjusted preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - if cfg.Options.Debug { - jsonData, _ := json.Marshal(preparedMessages) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } var opts []option.RequestOption if a.isThinkingEnabled() { @@ -294,7 +295,7 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, ) // If there is an error we are going to see if we can retry the call if err != nil { - slog.Error("Error in Anthropic API call", "error", err) + slog.Error("Anthropic API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries) retry, after, retryErr := a.shouldRetry(attempts, err) if retryErr != nil { return nil, retryErr @@ -327,7 +328,6 @@ func (a *anthropicClient) send(ctx context.Context, messages []message.Message, } func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent { - cfg := config.Get() attempts := 0 eventChan := make(chan ProviderEvent) go func() { @@ -335,10 +335,6 @@ func (a *anthropicClient) stream(ctx context.Context, messages []message.Message attempts++ // Prepare messages on each attempt in case max_tokens was adjusted preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools)) - if cfg.Options.Debug { - jsonData, _ := json.Marshal(preparedMessages) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } var opts []option.RequestOption if a.isThinkingEnabled() { diff --git a/internal/llm/provider/azure.go b/internal/llm/provider/azure.go index 31d06bd1b040d8f8cce3afa28fad53b0fe12eaa3..9042d66876c6f22bd9c06a5f52f6b4502e32c0f2 100644 --- a/internal/llm/provider/azure.go +++ b/internal/llm/provider/azure.go @@ -1,6 +1,8 @@ package provider import ( + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/log" "github.com/openai/openai-go" "github.com/openai/openai-go/azure" "github.com/openai/openai-go/option" @@ -22,6 +24,11 @@ func newAzureClient(opts providerClientOptions) AzureClient { azure.WithEndpoint(opts.baseURL, apiVersion), } + if config.Get().Options.Debug { + httpClient := log.NewHTTPClient() + reqOpts = append(reqOpts, option.WithHTTPClient(httpClient)) + } + reqOpts = append(reqOpts, azure.WithAPIKey(opts.apiKey)) base := &openaiClient{ providerOptions: opts, diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 0070d246012547a691f8c6a8cbd8de2234cd93ec..c7625670e35933597915ee73307f5956a9452814 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -13,6 +13,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/message" "github.com/google/uuid" "google.golang.org/genai" @@ -39,7 +40,14 @@ func newGeminiClient(opts providerClientOptions) GeminiClient { } func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI}) + cc := &genai.ClientConfig{ + APIKey: opts.apiKey, + Backend: genai.BackendGeminiAPI, + } + if config.Get().Options.Debug { + cc.HTTPClient = log.NewHTTPClient() + } + client, err := genai.NewClient(context.Background(), cc) if err != nil { return nil, err } @@ -166,10 +174,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too geminiMessages := g.convertMessages(messages) model := g.providerOptions.model(g.providerOptions.modelType) cfg := config.Get() - if cfg.Options.Debug { - jsonData, _ := json.Marshal(geminiMessages) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } modelConfig := cfg.Models[config.SelectedModelTypeLarge] if g.providerOptions.modelType == config.SelectedModelTypeSmall { @@ -266,10 +270,6 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t model := g.providerOptions.model(g.providerOptions.modelType) cfg := config.Get() - if cfg.Options.Debug { - jsonData, _ := json.Marshal(geminiMessages) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } modelConfig := cfg.Models[config.SelectedModelTypeLarge] if g.providerOptions.modelType == config.SelectedModelTypeSmall { diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index ee3a3113a001bbc31efcda867f1f1c62ae161173..b001353c9d94acebdf3eba9707c1525b65a38098 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -2,7 +2,6 @@ package provider import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -13,6 +12,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" + "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/message" "github.com/openai/openai-go" "github.com/openai/openai-go/option" @@ -46,6 +46,11 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { } } + if config.Get().Options.Debug { + httpClient := log.NewHTTPClient() + openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient)) + } + for key, value := range opts.extraHeaders { openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value)) } @@ -250,11 +255,6 @@ func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessagePar func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) { params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools)) - cfg := config.Get() - if cfg.Options.Debug { - jsonData, _ := json.Marshal(params) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } attempts := 0 for { attempts++ @@ -311,12 +311,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t IncludeUsage: openai.Bool(true), } - cfg := config.Get() - if cfg.Options.Debug { - jsonData, _ := json.Marshal(params) - slog.Debug("Prepared messages", "messages", string(jsonData)) - } - attempts := 0 eventChan := make(chan ProviderEvent) @@ -420,11 +414,6 @@ func (o *openaiClient) stream(ctx context.Context, messages []message.Message, t err := openaiStream.Err() if err == nil || errors.Is(err, io.EOF) { - if cfg.Options.Debug { - jsonData, _ := json.Marshal(acc.ChatCompletion) - slog.Debug("Response", "messages", string(jsonData)) - } - if len(acc.Choices) == 0 { eventChan <- ProviderEvent{ Type: EventError, @@ -525,7 +514,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) slog.Warn("Retry-After header", "values", retryAfterValues) } } else { - slog.Warn("OpenAI API error", "error", err.Error()) + slog.Error("OpenAI API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries) } backoffMs := 2000 * (1 << (attempts - 1)) diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index cbc86d8b7428639ea89ad49771ed4515d18adc07..871ff092b058af70833ba615260efcdbc09f2514 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -5,6 +5,8 @@ import ( "log/slog" "strings" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/log" "google.golang.org/genai" ) @@ -13,11 +15,15 @@ type VertexAIClient ProviderClient func newVertexAIClient(opts providerClientOptions) VertexAIClient { project := opts.extraParams["project"] location := opts.extraParams["location"] - client, err := genai.NewClient(context.Background(), &genai.ClientConfig{ + cc := &genai.ClientConfig{ Project: project, Location: location, Backend: genai.BackendVertexAI, - }) + } + if config.Get().Options.Debug { + cc.HTTPClient = log.NewHTTPClient() + } + client, err := genai.NewClient(context.Background(), cc) if err != nil { slog.Error("Failed to create VertexAI client", "error", err) return nil diff --git a/internal/log/http.go b/internal/log/http.go new file mode 100644 index 0000000000000000000000000000000000000000..1091e5706c09be374e6775f8906c91505e10b33f --- /dev/null +++ b/internal/log/http.go @@ -0,0 +1,125 @@ +package log + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "strings" + "time" +) + +// NewHTTPClient creates an HTTP client with debug logging enabled when debug mode is on. +func NewHTTPClient() *http.Client { + if !slog.Default().Enabled(context.TODO(), slog.LevelDebug) { + return http.DefaultClient + } + return &http.Client{ + Transport: &HTTPRoundTripLogger{ + Transport: http.DefaultTransport, + }, + } +} + +// HTTPRoundTripLogger is an http.RoundTripper that logs requests and responses. +type HTTPRoundTripLogger struct { + Transport http.RoundTripper +} + +// RoundTrip implements http.RoundTripper interface with logging. +func (h *HTTPRoundTripLogger) RoundTrip(req *http.Request) (*http.Response, error) { + var err error + var save io.ReadCloser + save, req.Body, err = drainBody(req.Body) + if err != nil { + slog.Error( + "HTTP request failed", + "method", req.Method, + "url", req.URL, + "error", err, + ) + return nil, err + } + + slog.Debug( + "HTTP Request", + "method", req.Method, + "url", req.URL, + "body", bodyToString(save), + ) + + start := time.Now() + resp, err := h.Transport.RoundTrip(req) + duration := time.Since(start) + if err != nil { + slog.Error( + "HTTP request failed", + "method", req.Method, + "url", req.URL, + "duration_ms", duration.Milliseconds(), + "error", err, + ) + return resp, err + } + + save, resp.Body, err = drainBody(resp.Body) + slog.Debug( + "HTTP Response", + "status_code", resp.StatusCode, + "status", resp.Status, + "headers", formatHeaders(resp.Header), + "body", bodyToString(save), + "content_length", resp.ContentLength, + "duration_ms", duration.Milliseconds(), + "error", err, + ) + return resp, err +} + +func bodyToString(body io.ReadCloser) string { + src, err := io.ReadAll(body) + if err != nil { + slog.Error("Failed to read body", "error", err) + return "" + } + var b bytes.Buffer + if json.Compact(&b, bytes.TrimSpace(src)) != nil { + // not json probably + return string(src) + } + return b.String() +} + +// formatHeaders formats HTTP headers for logging, filtering out sensitive information. +func formatHeaders(headers http.Header) map[string][]string { + filtered := make(map[string][]string) + for key, values := range headers { + lowerKey := strings.ToLower(key) + // Filter out sensitive headers + if strings.Contains(lowerKey, "authorization") || + strings.Contains(lowerKey, "api-key") || + strings.Contains(lowerKey, "token") || + strings.Contains(lowerKey, "secret") { + filtered[key] = []string{"[REDACTED]"} + } else { + filtered[key] = values + } + } + return filtered +} + +func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == nil || b == http.NoBody { + return http.NoBody, http.NoBody, nil + } + var buf bytes.Buffer + if _, err = buf.ReadFrom(b); err != nil { + return nil, b, err + } + if err = b.Close(); err != nil { + return nil, b, err + } + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil +} diff --git a/internal/log/http_test.go b/internal/log/http_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cf5a21185dce2ec2e84e1979d58047eb359103db --- /dev/null +++ b/internal/log/http_test.go @@ -0,0 +1,73 @@ +package log + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHTTPRoundTripLogger(t *testing.T) { + // Create a test server that returns a 500 error + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Custom-Header", "test-value") + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "Internal server error", "code": 500}`)) + })) + defer server.Close() + + // Create HTTP client with logging + client := NewHTTPClient() + + // Make a request + req, err := http.NewRequestWithContext( + t.Context(), + http.MethodPost, + server.URL, + strings.NewReader(`{"test": "data"}`), + ) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer secret-token") + + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + // Verify response + if resp.StatusCode != http.StatusInternalServerError { + t.Errorf("Expected status code 500, got %d", resp.StatusCode) + } +} + +func TestFormatHeaders(t *testing.T) { + headers := http.Header{ + "Content-Type": []string{"application/json"}, + "Authorization": []string{"Bearer secret-token"}, + "X-API-Key": []string{"api-key-123"}, + "User-Agent": []string{"test-agent"}, + } + + formatted := formatHeaders(headers) + + // Check that sensitive headers are redacted + if formatted["Authorization"][0] != "[REDACTED]" { + t.Error("Authorization header should be redacted") + } + if formatted["X-API-Key"][0] != "[REDACTED]" { + t.Error("X-API-Key header should be redacted") + } + + // Check that non-sensitive headers are preserved + if formatted["Content-Type"][0] != "application/json" { + t.Error("Content-Type header should be preserved") + } + if formatted["User-Agent"][0] != "test-agent" { + t.Error("User-Agent header should be preserved") + } +}