test: add concurrency test for digestAuthTransport to ensure thread safety

- Introduced a new test, TestDigestAuthTransportConcurrency, to validate concurrent access to the digestAuthTransport.
- Implemented checks to ensure the nonce count (nc) is correctly incremented and protected from race conditions using a mutex.
- Enhanced the digestAuthTransport struct to include a mutex for safe concurrent operations.
This commit is contained in:
0x524a
2025-12-01 23:35:15 -05:00
parent 1f68023dbe
commit 08d55b4cb9
2 changed files with 106 additions and 2 deletions
+7 -1
View File
@@ -401,6 +401,7 @@ type digestAuthTransport struct {
username string
password string
nc int
ncMu sync.Mutex // Protects nc field from concurrent access
}
// RoundTrip implements http.RoundTripper with digest auth support
@@ -452,8 +453,13 @@ func (d *digestAuthTransport) createDigestAuthHeader(req *http.Request, authHead
method := req.Method
ha2 := md5Hash(method + ":" + uri)
// Increment nonce count atomically to prevent race conditions
// HTTP transports must be safe for concurrent use
d.ncMu.Lock()
d.nc++
ncStr := fmt.Sprintf("%08x", d.nc)
nc := d.nc
d.ncMu.Unlock()
ncStr := fmt.Sprintf("%08x", nc)
cnonce := generateNonce()
var responseStr string
+99 -1
View File
@@ -1298,4 +1298,102 @@ func TestDownloadFileContextCancellation(t *testing.T) {
if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "context canceled") {
t.Errorf("Expected context error, got: %v", err)
}
}
}
// TestDigestAuthTransportConcurrency tests concurrent access to digestAuthTransport
// This verifies that the nc field is properly protected from race conditions
func TestDigestAuthTransportConcurrency(t *testing.T) {
nonce := "test-nonce"
realm := "test-realm"
opaque := "test-opaque"
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Digest ") {
w.Header().Set("WWW-Authenticate", fmt.Sprintf(
`Digest realm="%s", nonce="%s", opaque="%s", qop="auth"`,
realm, nonce, opaque))
w.WriteHeader(http.StatusUnauthorized)
return
}
// Verify nc (nonce count) is present and valid
if !strings.Contains(authHeader, "nc=") {
t.Error("Digest auth header missing nc (nonce count)")
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("success"))
}))
defer server.Close()
tr := &http.Transport{
Dial: (&net.Dialer{
Timeout: DefaultTimeout,
KeepAlive: DefaultTimeout,
}).Dial,
}
// Create a single transport instance that will be used concurrently
digestTransport := &digestAuthTransport{
transport: tr,
username: "admin",
password: "password",
}
digestClient := &http.Client{
Transport: digestTransport,
Timeout: DefaultTimeout,
}
// Make concurrent requests to verify no race conditions
const numRequests = 10
done := make(chan bool, numRequests)
errors := make(chan error, numRequests)
for i := 0; i < numRequests; i++ {
go func(id int) {
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
errors <- fmt.Errorf("request %d: NewRequest failed: %v", id, err)
done <- true
return
}
resp, err := digestClient.Do(req)
if err != nil {
errors <- fmt.Errorf("request %d: Do failed: %v", id, err)
done <- true
return
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
errors <- fmt.Errorf("request %d: expected 200, got %d", id, resp.StatusCode)
}
done <- true
}(i)
}
// Wait for all requests to complete
for i := 0; i < numRequests; i++ {
<-done
}
// Check for errors
close(errors)
for err := range errors {
if err != nil {
t.Error(err)
}
}
// Verify that nc was incremented correctly (should be at least numRequests)
// Note: Each request triggers 2 RoundTrip calls (initial + retry with auth),
// so nc should be at least numRequests
digestTransport.ncMu.Lock()
finalNC := digestTransport.nc
digestTransport.ncMu.Unlock()
if finalNC < numRequests {
t.Errorf("Expected nc >= %d, got %d", numRequests, finalNC)
}
}