From 6608b4f0f29f087961a530a32589e244eea84342 Mon Sep 17 00:00:00 2001 From: Amolith Date: Sun, 31 May 2026 15:35:04 -0600 Subject: [PATCH] client: refresh session cookies Use the standard HTTP cookie jar for password sessions so cookie expiry is handled by normal cookie rules. Retry once after an auth failure by expiring the stale session cookie and logging in again from the configured credentials. Add tests covering cookie expiry and server-side rejection of stale sessions. --- internal/silverbullet/client.go | 156 +++++++++++++++++++-------- internal/silverbullet/client_test.go | 127 ++++++++++++++++++++++ 2 files changed, 240 insertions(+), 43 deletions(-) diff --git a/internal/silverbullet/client.go b/internal/silverbullet/client.go index 80897a069d50800cb4efca77f3052e62d60f83de..cc7e3803a78e120d55f5ba37a788529fb97dc88d 100644 --- a/internal/silverbullet/client.go +++ b/internal/silverbullet/client.go @@ -12,9 +12,11 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/url" "strconv" "strings" + "sync" "time" ) @@ -33,18 +35,18 @@ type Client struct { baseURL string httpClient *http.Client auth Auth - // cached session from password login - cachedCookieName string - cachedCookieVal string + authMu sync.Mutex } // New creates a new SilverBullet client. // The baseURL is normalized to remove trailing slashes. func New(baseURL string, auth Auth) *Client { + jar, _ := cookiejar.New(nil) return &Client{ baseURL: strings.TrimRight(baseURL, "/"), httpClient: &http.Client{ Timeout: 6 * time.Hour, // long enough for lua_script with X-Timeout up to 21600s + Jar: jar, CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, @@ -92,11 +94,8 @@ func (c *Client) ExecuteLua(ctx context.Context, script string, timeout int) (*L req.Header.Set("Content-Type", "text/plain") req.Header.Set("X-Timeout", strconv.Itoa(timeout)) - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("executing lua: %w", err) } @@ -135,11 +134,7 @@ func (c *Client) Screenshot(ctx context.Context) ([]byte, error) { return nil, fmt.Errorf("creating request: %w", err) } - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("fetching screenshot: %w", err) } @@ -179,10 +174,6 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs return nil, fmt.Errorf("creating request: %w", err) } - if err := c.setAuth(req); err != nil { - return nil, fmt.Errorf("setting auth: %w", err) - } - q := req.URL.Query() q.Set("limit", strconv.Itoa(limit)) if since > 0 { @@ -190,7 +181,7 @@ func (c *Client) ConsoleLogs(ctx context.Context, limit int, since int64) (*Logs } req.URL.RawQuery = q.Encode() - resp, err := c.httpClient.Do(req) + resp, err := c.do(req) if err != nil { return nil, fmt.Errorf("fetching logs: %w", err) } @@ -218,9 +209,62 @@ func (c *Client) resolveURL(path string) (string, error) { return url.JoinPath(c.baseURL, path) } +// do sends an authenticated request. +// If a password session is rejected, it discards the cached cookie, logs in again, +// and retries the request once. +func (c *Client) do(req *http.Request) (*http.Response, error) { + if err := c.setAuth(req); err != nil { + return nil, fmt.Errorf("setting auth: %w", err) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + if !c.usesPasswordAuth() || !isAuthFailure(resp.StatusCode) { + return resp, nil + } + + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + + if req.Body != nil && req.GetBody == nil { + return nil, fmt.Errorf("authentication failed and request body cannot be replayed") + } + if err := c.expireSessionCookies(req.URL); err != nil { + return nil, fmt.Errorf("expiring session cookie: %w", err) + } + + retryReq := req.Clone(req.Context()) + retryReq.Header = req.Header.Clone() + retryReq.Header.Del("Cookie") + if req.GetBody != nil { + retryBody, err := req.GetBody() + if err != nil { + return nil, fmt.Errorf("recreating request body: %w", err) + } + retryReq.Body = retryBody + } + + if err := c.setAuth(retryReq); err != nil { + return nil, fmt.Errorf("refreshing auth: %w", err) + } + + return c.httpClient.Do(retryReq) +} + +func (c *Client) usesPasswordAuth() bool { + return c.auth.User != "" && c.auth.Pass != "" +} + +func isAuthFailure(status int) bool { + return status == http.StatusUnauthorized || (status >= 300 && status < 400) +} + // setAuth sets authentication headers on the request. // Bearer token goes on the Authorization header. -// Password auth logs in via POST /.auth and sends the session cookie. +// Password auth logs in via POST /.auth; the client's cookie jar sends the +// resulting session cookie on subsequent requests until it expires. // Both can coexist — they use different headers. func (c *Client) setAuth(req *http.Request) error { // Bearer token on Authorization header @@ -229,7 +273,7 @@ func (c *Client) setAuth(req *http.Request) error { } // Password auth via session cookie - if c.auth.User != "" && c.auth.Pass != "" { + if c.usesPasswordAuth() { if err := c.ensureSessionCookie(req); err != nil { return fmt.Errorf("session login: %w", err) } @@ -238,27 +282,29 @@ func (c *Client) setAuth(req *http.Request) error { return nil } -// ensureSessionCookie logs in via POST /.auth if needed and sets the session cookie. +// ensureSessionCookie logs in via POST /.auth if the cookie jar has no current +// auth cookie. The jar applies normal cookie expiration rules. func (c *Client) ensureSessionCookie(req *http.Request) error { - // Use cached session if available - if c.cachedCookieName != "" && c.cachedCookieVal != "" { - req.AddCookie(&http.Cookie{ - Name: c.cachedCookieName, - Value: c.cachedCookieVal, - }) + loginURL, err := c.resolveURL("/.auth") + if err != nil { + return fmt.Errorf("resolving login URL: %w", err) + } + loginCookieURL, err := url.Parse(loginURL) + if err != nil { + return fmt.Errorf("parsing login URL: %w", err) + } + + c.authMu.Lock() + defer c.authMu.Unlock() + + if c.hasAuthCookie(loginCookieURL) { return nil } - // Log in to get session cookie form := url.Values{} form.Set("username", c.auth.User) form.Set("password", c.auth.Pass) - loginURL, err := c.resolveURL("/.auth") - if err != nil { - return fmt.Errorf("resolving login URL: %w", err) - } - loginReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, loginURL, strings.NewReader(form.Encode())) if err != nil { return fmt.Errorf("creating login request: %w", err) @@ -274,19 +320,43 @@ func (c *Client) ensureSessionCookie(req *http.Request) error { if loginResp.StatusCode != http.StatusOK { return fmt.Errorf("login failed with status %d", loginResp.StatusCode) } + if !c.hasAuthCookie(loginCookieURL) { + return fmt.Errorf("login succeeded but no auth cookie returned") + } - // Extract session cookie from Set-Cookie header - for _, cookie := range loginResp.Cookies() { - if strings.HasPrefix(cookie.Name, "auth_") { - c.cachedCookieName = cookie.Name - c.cachedCookieVal = cookie.Value - req.AddCookie(&http.Cookie{ - Name: cookie.Name, - Value: cookie.Value, - }) - return nil + return nil +} + +func (c *Client) hasAuthCookie(u *url.URL) bool { + if c.httpClient.Jar == nil { + return false + } + for _, cookie := range c.httpClient.Jar.Cookies(u) { + if strings.HasPrefix(cookie.Name, "auth_") && cookie.Value != "" { + return true } } + return false +} + +func (c *Client) expireSessionCookies(u *url.URL) error { + if c.httpClient.Jar == nil { + return nil + } + + c.authMu.Lock() + defer c.authMu.Unlock() - return fmt.Errorf("login succeeded but no auth cookie returned") + cookies := c.httpClient.Jar.Cookies(u) + for _, cookie := range cookies { + if strings.HasPrefix(cookie.Name, "auth_") { + c.httpClient.Jar.SetCookies(u, []*http.Cookie{{ + Name: cookie.Name, + Value: "", + Path: "/", + MaxAge: -1, + }}) + } + } + return nil } diff --git a/internal/silverbullet/client_test.go b/internal/silverbullet/client_test.go index db6d3f593ef67a2f4c5e31a36385e685b1cc2cc2..cab9ab3122a95d1d13e9463c863a570db79a2c8f 100644 --- a/internal/silverbullet/client_test.go +++ b/internal/silverbullet/client_test.go @@ -10,7 +10,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync" "testing" + "time" ) func newTestServer(mux *http.ServeMux) *httptest.Server { @@ -312,6 +314,131 @@ func TestBothAuth(t *testing.T) { } } +func TestPasswordAuthReloginsAfterCookieExpires(t *testing.T) { + mux := http.NewServeMux() + var mu sync.Mutex + loginCount := 0 + currentToken := "" + + mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + loginCount++ + currentToken = fmt.Sprintf("mock-jwt-token-%d", loginCount) + token := currentToken + mu.Unlock() + + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: token, + MaxAge: 1, + }) + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/.runtime/logs", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("auth_session") + if err != nil { + t.Error("expected auth_session cookie, got none") + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + expectedToken := currentToken + mu.Unlock() + if cookie.Value != expectedToken { + t.Errorf("expected cookie value %q, got %q", expectedToken, cookie.Value) + } + _ = json.NewEncoder(w).Encode(LogsResult{Logs: []LogEntry{}}) + }) + + srv := newTestServer(mux) + defer srv.Close() + + client := testClientWithBasicAuth(srv.URL) + if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil { + t.Fatalf("first ConsoleLogs failed: %v", err) + } + time.Sleep(1100 * time.Millisecond) + if _, err := client.ConsoleLogs(context.Background(), 1, 0); err != nil { + t.Fatalf("second ConsoleLogs failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if loginCount != 2 { + t.Fatalf("expected 2 logins after cookie expiry, got %d", loginCount) + } +} + +func TestPasswordAuthRefreshesRejectedSession(t *testing.T) { + mux := http.NewServeMux() + var mu sync.Mutex + loginCount := 0 + + mux.HandleFunc("/.auth", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + if r.FormValue("username") != "testuser" || r.FormValue("password") != "testpass" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + mu.Lock() + loginCount++ + token := fmt.Sprintf("mock-jwt-token-%d", loginCount) + mu.Unlock() + + http.SetCookie(w, &http.Cookie{ + Name: "auth_session", + Value: token, + }) + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("/.runtime/lua_script", func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie("auth_session") + if err != nil { + t.Error("expected auth_session cookie, got none") + w.WriteHeader(http.StatusUnauthorized) + return + } + if cookie.Value == "mock-jwt-token-1" { + w.WriteHeader(http.StatusUnauthorized) + return + } + if cookie.Value != "mock-jwt-token-2" { + t.Errorf("expected refreshed cookie value, got %q", cookie.Value) + } + + result := LuaResult{Result: json.RawMessage(`"ok"`)} + _ = json.NewEncoder(w).Encode(result) + }) + + srv := newTestServer(mux) + defer srv.Close() + + client := testClientWithBasicAuth(srv.URL) + if _, err := client.ExecuteLua(context.Background(), "return 'ok'", 30); err != nil { + t.Fatalf("ExecuteLua failed: %v", err) + } + + mu.Lock() + defer mu.Unlock() + if loginCount != 2 { + t.Fatalf("expected 2 logins after rejected session, got %d", loginCount) + } +} + func TestTrailingSlashURL(t *testing.T) { // Verify that trailing slashes in the base URL are handled correctly client := New("http://localhost:3000/", Auth{})