diff --git a/internal/silverbullet/client.go b/internal/silverbullet/client.go index 80897a069d50800cb4efca77f3052e62d60f83de..cc7e3803a78e120d55f5ba37a788529fb97dc88d 100644 --- a/internal/silverbullet/client.go +++ b/internal/silverbullet/client.go @@ -12,9 +12,11 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/url" "strconv" "strings" + "sync" "time" ) @@ -33,18 +35,18 @@ type Client struct { baseURL string httpClient *http.Client auth Auth - // cached session from password login - cachedCookieName string - cachedCookieVal string + authMu sync.Mutex } // New creates a new SilverBullet client. // The baseURL is normalized to remove trailing slashes. func New(baseURL string, auth Auth) *Client { + jar, _ := cookiejar.New(nil) return &Client{ baseURL: strings.TrimRight(baseURL, "/"), httpClient: &http.Client{ Timeout: 6 * time.Hour, // long enough for lua_script with X-Timeout up to 21600s + Jar: jar, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, @@ -92,11 +94,8 @@ func (c *Client) ExecuteLua(ctx context.Context, script string, timeout int) (*L req.Header.Set("Content-Type", "text/plain") req.Header.Set("X-Timeout", strconv.Itoa(timeout)) - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("executing lua: %w", err) } @@ -135,11 +134,7 @@ func (c *Client) Screenshot(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("creating request: %w", err) } - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("fetching screenshot: %w", err) } @@ -179,10 +174,6 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs return nil, fmt.Errorf("creating request: %w", err) } - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - q := req.URL.Query() q.Set("limit", strconv.Itoa(limit)) if since > 0 { @@ -190,7 +181,7 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs } req.URL.RawQuery = q.Encode() - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("fetching logs: %w", err) } @@ -218,9 +209,62 @@ func (c *Client) resolveURL(path string) (string, error) { return url.JoinPath(c.baseURL, path) } +// do sends an authenticated request. +// If a password session is rejected, it discards the cached cookie, logs in again, +// and retries the request once. +func (c *Client) do(req *http.Request) (*http.Response, error) { + if err := c.setAuth(req); err != nil { + return nil, fmt.Errorf("setting auth: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if !c.usesPasswordAuth() || !isAuthFailure(resp.StatusCode) { + return resp, nil + } + + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + if req.Body != nil && req.GetBody == nil { + return nil, fmt.Errorf("authentication failed and request body cannot be replayed") + } + if err := c.expireSessionCookies(req.URL); err != nil { + return nil, fmt.Errorf("expiring session cookie: %w", err) + } + + retryReq := req.Clone(req.Context()) + retryReq.Header = req.Header.Clone() + retryReq.Header.Del("Cookie") + if req.GetBody != nil { + retryBody, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("recreating request body: %w", err) + } + retryReq.Body = retryBody + } + + if err := c.setAuth(retryReq); err != nil { + return nil, fmt.Errorf("refreshing auth: %w", err) + } + + return c.httpClient.Do(retryReq) +} + +func (c *Client) usesPasswordAuth() bool { + return c.auth.User != "" && c.auth.Pass != "" +} + +func isAuthFailure(status int) bool { + return status == http.StatusUnauthorized || (status >= 300 && status < 400) +} + // setAuth sets authentication headers on the request. // Bearer token goes on the Authorization header. -// Password auth logs in via POST /.auth and sends the session cookie. +// Password auth logs in via POST /.auth; the client's cookie jar sends the +// resulting session cookie on subsequent requests until it expires. // Both can coexist — they use different headers. func (c *Client) setAuth(req *http.Request) error { // Bearer token on Authorization header @@ -229,7 +273,7 @@ func (c *Client) setAuth(req *http.Request) error { } // Password auth via session cookie - if c.auth.User != "" && c.auth.Pass != "" { + if c.usesPasswordAuth() { if err := c.ensureSessionCookie(req); err != nil { return fmt.Errorf("session login: %w", err) } @@ -238,27 +282,29 @@ func (c *Client) setAuth(req *http.Request) error { return nil } -// ensureSessionCookie logs in via POST /.auth if needed and sets the session cookie. +// ensureSessionCookie logs in via POST /.auth if the cookie jar has no current +// auth cookie. The jar applies normal cookie expiration rules. func (c *Client) ensureSessionCookie(req *http.Request) error { - // Use cached session if available - if c.cachedCookieName != "" && c.cachedCookieVal != "" { - req.AddCookie(&http.Cookie{ - Name: c.cachedCookieName, - Value: c.cachedCookieVal, - }) + loginURL, err := c.resolveURL("/.auth") + if err != nil { + return fmt.Errorf("resolving login URL: %w", err) + } + loginCookieURL, err := url.Parse(loginURL) + if err != nil { + return fmt.Errorf("parsing login URL: %w", err) + } + + c.authMu.Lock() + defer c.authMu.Unlock() + + if c.hasAuthCookie(loginCookieURL) { return nil } - // Log in to get session cookie form := url.Values{} form.Set("username", c.auth.User) form.Set("password", c.auth.Pass) - loginURL, err := c.resolveURL("/.auth") - if err != nil { - return fmt.Errorf("resolving login URL: %w", err) - } - loginReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, loginURL, strings.NewReader(form.Encode())) if err != nil { return fmt.Errorf("creating login request: %w", err) @@ -274,19 +320,43 @@ func (c *Client) ensureSessionCookie(req *http.Request) error { if loginResp.StatusCode != http.StatusOK { return fmt.Errorf("login failed with status %d", loginResp.StatusCode) } + if !c.hasAuthCookie(loginCookieURL) { + return fmt.Errorf("login succeeded but no auth cookie returned") + } - // Extract session cookie from Set-Cookie header - for _, cookie := range loginResp.Cookies() { - if strings.HasPrefix(cookie.Name, "auth_") { - c.cachedCookieName = cookie.Name - c.cachedCookieVal = cookie.Value - req.AddCookie(&http.Cookie{ - Name: cookie.Name, - Value: cookie.Value, - }) - return nil + return nil +} + +func (c *Client) hasAuthCookie(u *url.URL) bool { + if c.httpClient.Jar == nil { + return false + } + for _, cookie := range c.httpClient.Jar.Cookies(u) { + if strings.HasPrefix(cookie.Name, "auth_") && cookie.Value != "" { + return true } } + return false +} + +func (c *Client) expireSessionCookies(u *url.URL) error { + if c.httpClient.Jar == nil { + return nil + } + + c.authMu.Lock() + defer c.authMu.Unlock() - return fmt.Errorf("login succeeded but no auth cookie returned") + cookies := c.httpClient.Jar.Cookies(u) + for _, cookie := range cookies { + if strings.HasPrefix(cookie.Name, "auth_") { + c.httpClient.Jar.SetCookies(u, []*http.Cookie{{ + Name: cookie.Name, + Value: "", + Path: "/", + MaxAge: -1, + }}) + } + } + return nil } diff --git a/internal/silverbullet/client_test.go b/internal/silverbullet/client_test.go index db6d3f593ef67a2f4c5e31a36385e685b1cc2cc2..cab9ab3122a95d1d13e9463c863a570db79a2c8f 100644 --- a/internal/silverbullet/client_test.go +++ b/internal/silverbullet/client_test.go @@ -10,7 +10,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" + "time" ) func newTestServer(mux *http.ServeMux) *httptest.Server { @@ -312,6 +314,131 @@ func TestBothAuth(t *testing.T) { } } +func TestPasswordAuthReloginsAfterCookieExpires(t *testing.T) { + mux := http.NewServeMux() + var mu sync.Mutex + loginCount := 0 + currentToken := "" + + mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + loginCount++ + currentToken = fmt.Sprintf("mock-jwt-token-%d", loginCount) + token := currentToken + mu.Unlock() + + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: token, + MaxAge: 1, + }) + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/.runtime/logs", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("auth_session") + if err != nil { + t.Error("expected auth_session cookie, got none") + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + expectedToken := currentToken + mu.Unlock() + if cookie.Value != expectedToken { + t.Errorf("expected cookie value %q, got %q", expectedToken, cookie.Value) + } + _ = json.NewEncoder(w).Encode(LogsResult{Logs: []LogEntry{}}) + }) + + srv := newTestServer(mux) + defer srv.Close() + + client := testClientWithBasicAuth(srv.URL) + if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil { + t.Fatalf("first ConsoleLogs failed: %v", err) + } + time.Sleep(1100 * time.Millisecond) + if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil { + t.Fatalf("second ConsoleLogs failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if loginCount != 2 { + t.Fatalf("expected 2 logins after cookie expiry, got %d", loginCount) + } +} + +func TestPasswordAuthRefreshesRejectedSession(t *testing.T) { + mux := http.NewServeMux() + var mu sync.Mutex + loginCount := 0 + + mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + loginCount++ + token := fmt.Sprintf("mock-jwt-token-%d", loginCount) + mu.Unlock() + + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: token, + }) + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/.runtime/lua_script", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("auth_session") + if err != nil { + t.Error("expected auth_session cookie, got none") + w.WriteHeader(http.StatusUnauthorized) + return + } + if cookie.Value == "mock-jwt-token-1" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if cookie.Value != "mock-jwt-token-2" { + t.Errorf("expected refreshed cookie value, got %q", cookie.Value) + } + + result := LuaResult{Result: json.RawMessage(`"ok"`)} + _ = json.NewEncoder(w).Encode(result) + }) + + srv := newTestServer(mux) + defer srv.Close() + + client := testClientWithBasicAuth(srv.URL) + if _, err := client.ExecuteLua(context.Background(), "return 'ok'", 30); err != nil { + t.Fatalf("ExecuteLua failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if loginCount != 2 { + t.Fatalf("expected 2 logins after rejected session, got %d", loginCount) + } +} + func TestTrailingSlashURL(t *testing.T) { // Verify that trailing slashes in the base URL are handled correctly client := New("http://localhost:3000/", Auth{})