diff --git a/client.go b/client.go index 26a25bf..8452fa0 100644 --- a/client.go +++ b/client.go @@ -401,6 +401,7 @@ type digestAuthTransport struct { username string password string nc int + ncMu sync.Mutex // Protects nc field from concurrent access } // RoundTrip implements http.RoundTripper with digest auth support @@ -452,8 +453,13 @@ func (d *digestAuthTransport) createDigestAuthHeader(req *http.Request, authHead method := req.Method ha2 := md5Hash(method + ":" + uri) + // Increment nonce count atomically to prevent race conditions + // HTTP transports must be safe for concurrent use + d.ncMu.Lock() d.nc++ - ncStr := fmt.Sprintf("%08x", d.nc) + nc := d.nc + d.ncMu.Unlock() + ncStr := fmt.Sprintf("%08x", nc) cnonce := generateNonce() var responseStr string diff --git a/client_test.go b/client_test.go index 7a58c92..6cb5555 100644 --- a/client_test.go +++ b/client_test.go @@ -1298,4 +1298,102 @@ func TestDownloadFileContextCancellation(t *testing.T) { if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "context canceled") { t.Errorf("Expected context error, got: %v", err) } -} \ No newline at end of file +} + +// TestDigestAuthTransportConcurrency tests concurrent access to digestAuthTransport +// This verifies that the nc field is properly protected from race conditions +func TestDigestAuthTransportConcurrency(t *testing.T) { + nonce := "test-nonce" + realm := "test-realm" + opaque := "test-opaque" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if authHeader == "" || !strings.HasPrefix(authHeader, "Digest ") { + w.Header().Set("WWW-Authenticate", fmt.Sprintf( + `Digest realm="%s", nonce="%s", opaque="%s", qop="auth"`, + realm, nonce, opaque)) + w.WriteHeader(http.StatusUnauthorized) + return + } + // Verify nc (nonce count) is present and valid + if !strings.Contains(authHeader, "nc=") { + t.Error("Digest auth header missing nc (nonce count)") + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("success")) + })) + defer server.Close() + + tr := &http.Transport{ + Dial: (&net.Dialer{ + Timeout: DefaultTimeout, + KeepAlive: DefaultTimeout, + }).Dial, + } + + // Create a single transport instance that will be used concurrently + digestTransport := &digestAuthTransport{ + transport: tr, + username: "admin", + password: "password", + } + + digestClient := &http.Client{ + Transport: digestTransport, + Timeout: DefaultTimeout, + } + + // Make concurrent requests to verify no race conditions + const numRequests = 10 + done := make(chan bool, numRequests) + errors := make(chan error, numRequests) + + for i := 0; i < numRequests; i++ { + go func(id int) { + req, err := http.NewRequest("GET", server.URL, nil) + if err != nil { + errors <- fmt.Errorf("request %d: NewRequest failed: %v", id, err) + done <- true + return + } + + resp, err := digestClient.Do(req) + if err != nil { + errors <- fmt.Errorf("request %d: Do failed: %v", id, err) + done <- true + return + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + errors <- fmt.Errorf("request %d: expected 200, got %d", id, resp.StatusCode) + } + done <- true + }(i) + } + + // Wait for all requests to complete + for i := 0; i < numRequests; i++ { + <-done + } + + // Check for errors + close(errors) + for err := range errors { + if err != nil { + t.Error(err) + } + } + + // Verify that nc was incremented correctly (should be at least numRequests) + // Note: Each request triggers 2 RoundTrip calls (initial + retry with auth), + // so nc should be at least numRequests + digestTransport.ncMu.Lock() + finalNC := digestTransport.nc + digestTransport.ncMu.Unlock() + + if finalNC < numRequests { + t.Errorf("Expected nc >= %d, got %d", numRequests, finalNC) + } +}