test: add concurrency test for digestAuthTransport to ensure thread safety
- Introduced a new test, TestDigestAuthTransportConcurrency, to validate concurrent access to the digestAuthTransport. - Implemented checks to ensure the nonce count (nc) is correctly incremented and protected from race conditions using a mutex. - Enhanced the digestAuthTransport struct to include a mutex for safe concurrent operations.
This commit is contained in:
@@ -401,6 +401,7 @@ type digestAuthTransport struct {
|
||||
username string
|
||||
password string
|
||||
nc int
|
||||
ncMu sync.Mutex // Protects nc field from concurrent access
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper with digest auth support
|
||||
@@ -452,8 +453,13 @@ func (d *digestAuthTransport) createDigestAuthHeader(req *http.Request, authHead
|
||||
method := req.Method
|
||||
ha2 := md5Hash(method + ":" + uri)
|
||||
|
||||
// Increment nonce count atomically to prevent race conditions
|
||||
// HTTP transports must be safe for concurrent use
|
||||
d.ncMu.Lock()
|
||||
d.nc++
|
||||
ncStr := fmt.Sprintf("%08x", d.nc)
|
||||
nc := d.nc
|
||||
d.ncMu.Unlock()
|
||||
ncStr := fmt.Sprintf("%08x", nc)
|
||||
cnonce := generateNonce()
|
||||
|
||||
var responseStr string
|
||||
|
||||
+99
-1
@@ -1298,4 +1298,102 @@ func TestDownloadFileContextCancellation(t *testing.T) {
|
||||
if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "context canceled") {
|
||||
t.Errorf("Expected context error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDigestAuthTransportConcurrency tests concurrent access to digestAuthTransport
|
||||
// This verifies that the nc field is properly protected from race conditions
|
||||
func TestDigestAuthTransportConcurrency(t *testing.T) {
|
||||
nonce := "test-nonce"
|
||||
realm := "test-realm"
|
||||
opaque := "test-opaque"
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" || !strings.HasPrefix(authHeader, "Digest ") {
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
|
||||
`Digest realm="%s", nonce="%s", opaque="%s", qop="auth"`,
|
||||
realm, nonce, opaque))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// Verify nc (nonce count) is present and valid
|
||||
if !strings.Contains(authHeader, "nc=") {
|
||||
t.Error("Digest auth header missing nc (nonce count)")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tr := &http.Transport{
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: DefaultTimeout,
|
||||
KeepAlive: DefaultTimeout,
|
||||
}).Dial,
|
||||
}
|
||||
|
||||
// Create a single transport instance that will be used concurrently
|
||||
digestTransport := &digestAuthTransport{
|
||||
transport: tr,
|
||||
username: "admin",
|
||||
password: "password",
|
||||
}
|
||||
|
||||
digestClient := &http.Client{
|
||||
Transport: digestTransport,
|
||||
Timeout: DefaultTimeout,
|
||||
}
|
||||
|
||||
// Make concurrent requests to verify no race conditions
|
||||
const numRequests = 10
|
||||
done := make(chan bool, numRequests)
|
||||
errors := make(chan error, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(id int) {
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("request %d: NewRequest failed: %v", id, err)
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := digestClient.Do(req)
|
||||
if err != nil {
|
||||
errors <- fmt.Errorf("request %d: Do failed: %v", id, err)
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errors <- fmt.Errorf("request %d: expected 200, got %d", id, resp.StatusCode)
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
for i := 0; i < numRequests; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Check for errors
|
||||
close(errors)
|
||||
for err := range errors {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that nc was incremented correctly (should be at least numRequests)
|
||||
// Note: Each request triggers 2 RoundTrip calls (initial + retry with auth),
|
||||
// so nc should be at least numRequests
|
||||
digestTransport.ncMu.Lock()
|
||||
finalNC := digestTransport.nc
|
||||
digestTransport.ncMu.Unlock()
|
||||
|
||||
if finalNC < numRequests {
|
||||
t.Errorf("Expected nc >= %d, got %d", numRequests, finalNC)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user