test: add concurrency test for digestAuthTransport to ensure thread safety
- Introduced a new test, TestDigestAuthTransportConcurrency, to validate concurrent access to the digestAuthTransport. - Implemented checks to ensure the nonce count (nc) is correctly incremented and protected from race conditions using a mutex. - Enhanced the digestAuthTransport struct to include a mutex for safe concurrent operations.
This commit is contained in:
@@ -401,6 +401,7 @@ type digestAuthTransport struct {
|
|||||||
username string
|
username string
|
||||||
password string
|
password string
|
||||||
nc int
|
nc int
|
||||||
|
ncMu sync.Mutex // Protects nc field from concurrent access
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip implements http.RoundTripper with digest auth support
|
// RoundTrip implements http.RoundTripper with digest auth support
|
||||||
@@ -452,8 +453,13 @@ func (d *digestAuthTransport) createDigestAuthHeader(req *http.Request, authHead
|
|||||||
method := req.Method
|
method := req.Method
|
||||||
ha2 := md5Hash(method + ":" + uri)
|
ha2 := md5Hash(method + ":" + uri)
|
||||||
|
|
||||||
|
// Increment nonce count atomically to prevent race conditions
|
||||||
|
// HTTP transports must be safe for concurrent use
|
||||||
|
d.ncMu.Lock()
|
||||||
d.nc++
|
d.nc++
|
||||||
ncStr := fmt.Sprintf("%08x", d.nc)
|
nc := d.nc
|
||||||
|
d.ncMu.Unlock()
|
||||||
|
ncStr := fmt.Sprintf("%08x", nc)
|
||||||
cnonce := generateNonce()
|
cnonce := generateNonce()
|
||||||
|
|
||||||
var responseStr string
|
var responseStr string
|
||||||
|
|||||||
+99
-1
@@ -1298,4 +1298,102 @@ func TestDownloadFileContextCancellation(t *testing.T) {
|
|||||||
if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "context canceled") {
|
if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "context canceled") {
|
||||||
t.Errorf("Expected context error, got: %v", err)
|
t.Errorf("Expected context error, got: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDigestAuthTransportConcurrency tests concurrent access to digestAuthTransport
|
||||||
|
// This verifies that the nc field is properly protected from race conditions
|
||||||
|
func TestDigestAuthTransportConcurrency(t *testing.T) {
|
||||||
|
nonce := "test-nonce"
|
||||||
|
realm := "test-realm"
|
||||||
|
opaque := "test-opaque"
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" || !strings.HasPrefix(authHeader, "Digest ") {
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||||
|
`Digest realm="%s", nonce="%s", opaque="%s", qop="auth"`,
|
||||||
|
realm, nonce, opaque))
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Verify nc (nonce count) is present and valid
|
||||||
|
if !strings.Contains(authHeader, "nc=") {
|
||||||
|
t.Error("Digest auth header missing nc (nonce count)")
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("success"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tr := &http.Transport{
|
||||||
|
Dial: (&net.Dialer{
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
KeepAlive: DefaultTimeout,
|
||||||
|
}).Dial,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a single transport instance that will be used concurrently
|
||||||
|
digestTransport := &digestAuthTransport{
|
||||||
|
transport: tr,
|
||||||
|
username: "admin",
|
||||||
|
password: "password",
|
||||||
|
}
|
||||||
|
|
||||||
|
digestClient := &http.Client{
|
||||||
|
Transport: digestTransport,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make concurrent requests to verify no race conditions
|
||||||
|
const numRequests = 10
|
||||||
|
done := make(chan bool, numRequests)
|
||||||
|
errors := make(chan error, numRequests)
|
||||||
|
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
req, err := http.NewRequest("GET", server.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
errors <- fmt.Errorf("request %d: NewRequest failed: %v", id, err)
|
||||||
|
done <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := digestClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
errors <- fmt.Errorf("request %d: Do failed: %v", id, err)
|
||||||
|
done <- true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errors <- fmt.Errorf("request %d: expected 200, got %d", id, resp.StatusCode)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all requests to complete
|
||||||
|
for i := 0; i < numRequests; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for errors
|
||||||
|
close(errors)
|
||||||
|
for err := range errors {
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that nc was incremented correctly (should be at least numRequests)
|
||||||
|
// Note: Each request triggers 2 RoundTrip calls (initial + retry with auth),
|
||||||
|
// so nc should be at least numRequests
|
||||||
|
digestTransport.ncMu.Lock()
|
||||||
|
finalNC := digestTransport.nc
|
||||||
|
digestTransport.ncMu.Unlock()
|
||||||
|
|
||||||
|
if finalNC < numRequests {
|
||||||
|
t.Errorf("Expected nc >= %d, got %d", numRequests, finalNC)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user