diff --git a/attack.go b/attack.go index b9fd620..d76a94c 100644 --- a/attack.go +++ b/attack.go @@ -101,13 +101,7 @@ func DetectAuthMethods(c Curler, targets []Stream, timeout time.Duration, log bo attacks := make(chan Stream) defer close(attacks) - validate := v.New() for i := range targets { - err := validate.Struct(targets[i]) - if err != nil { - return targets, errors.Wrap(err, "invalid targets") - } - targets[i].AuthenticationType = detectAuthMethod(c, targets[i], timeout, log) } diff --git a/attack_test.go b/attack_test.go index 5849616..f2e0062 100644 --- a/attack_test.go +++ b/attack_test.go @@ -2,14 +2,12 @@ package cmrdr import ( "errors" - "fmt" - "os" "testing" "time" - curl "github.com/ullaakut/go-curl" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + curl "github.com/ullaakut/go-curl" ) type CurlerMock struct { @@ -163,14 +161,12 @@ func TestAttackCredentials(t *testing.T) { if len(test.expectedErrMsg) > 0 { if err == nil { - fmt.Printf("unexpected success in AttackCredentials test, iteration %d. expected error: %s\n", i, test.expectedErrMsg) - os.Exit(1) + t.Errorf("unexpected success in AttackCredentials test, iteration %d. expected error: %s\n", i, test.expectedErrMsg) } assert.Contains(t, err.Error(), test.expectedErrMsg, "wrong error message") } else { if err != nil { - fmt.Printf("unexpected error in AttackCredentials test, iteration %d: %v\n", i, err) - os.Exit(1) + t.Errorf("unexpected error in AttackCredentials test, iteration %d: %v\n", i, err) } for _, stream := range test.expectedStreams { foundStream := false @@ -318,15 +314,13 @@ func TestAttackRoute(t *testing.T) { if len(test.expectedErrMsg) > 0 { if err == nil { - fmt.Printf("unexpected success in AttackRoute test, iteration %d. expected error: %s\n", i, test.expectedErrMsg) - os.Exit(1) + t.Errorf("unexpected success in AttackRoute test, iteration %d. expected error: %s\n", i, test.expectedErrMsg) } assert.Contains(t, err.Error(), test.expectedErrMsg, "wrong error message") } else { if err != nil { - fmt.Printf("unexpected error in AttackRoute test, iteration %d: %v\n", i, err) - os.Exit(1) + t.Errorf("unexpected error in AttackRoute test, iteration %d: %v\n", i, err) } for _, stream := range test.expectedStreams { @@ -486,15 +480,130 @@ func TestValidateStreams(t *testing.T) { if len(tC.expectedErrMsg) > 0 { if err == nil { - fmt.Printf("unexpected success in ValidateStream test, iteration %d. expected error: %s\n", i, tC.expectedErrMsg) - os.Exit(1) + t.Errorf("unexpected success in ValidateStream test, iteration %d. expected error: %s\n", i, tC.expectedErrMsg) } assert.Contains(t, err.Error(), tC.expectedErrMsg, "wrong error message") } else { if err != nil { - fmt.Printf("unexpected error in ValidateStream test, iteration %d: %v\n", i, err) - os.Exit(1) + t.Errorf("unexpected error in ValidateStream test, iteration %d: %v\n", i, err) + } + + for _, stream := range tC.expectedStreams { + foundStream := false + for _, result := range results { + if result.Address == stream.Address && result.Device == stream.Device && result.Port == stream.Port { + foundStream = true + } + } + + assert.Equal(t, true, foundStream, "wrong streams parsed") + } + } + + assert.Equal(t, len(tC.expectedStreams), len(results), "wrong streams parsed") + + curlerMock.AssertExpectations(t) + }) + } +} + +func TestDetectAuthenticationType(t *testing.T) { + validStream1 := Stream{ + Device: "fakeDevice", + Address: "fakeAddress", + Port: 1337, + Available: true, + } + + validStream2 := Stream{ + Device: "fakeDevice", + Address: "differentFakeAddress", + Port: 1337, + Available: true, + } + + fakeTargets := []Stream{validStream1, validStream2} + + testCases := []struct { + desc string + + targets []Stream + timeout time.Duration + log bool + + status int + + performErr error + getInfoErr error + + expectedStreams []Stream + expectedErrMsg string + }{ + // curl getinfo fails + { + desc: "curl getinfo fails", + + targets: fakeTargets, + timeout: 1 * time.Millisecond, + + getInfoErr: errors.New("dummy error"), + + expectedStreams: fakeTargets, + }, + // curl perform fails + { + desc: "curl perform fails", + + targets: fakeTargets, + timeout: 1 * time.Millisecond, + + performErr: errors.New("dummy error"), + + expectedStreams: fakeTargets, + }, + // Logs disabled + { + desc: "logs disabled", + + targets: fakeTargets, + timeout: 1 * time.Millisecond, + log: false, + + expectedStreams: fakeTargets, + }, + // Logs enabled + { + desc: "logs enabled", + + targets: fakeTargets, + timeout: 1 * time.Millisecond, + log: true, + + expectedStreams: fakeTargets, + }, + } + for i, tC := range testCases { + t.Run(tC.desc, func(t *testing.T) { + curlerMock := &CurlerMock{} + + curlerMock.On("Setopt", mock.Anything, mock.Anything).Return(nil) + curlerMock.On("Perform").Return(tC.performErr) + if tC.performErr == nil { + curlerMock.On("Getinfo", mock.Anything).Return(tC.status, tC.getInfoErr) + } + + results, err := DetectAuthMethods(curlerMock, tC.targets, tC.timeout, tC.log) + + if len(tC.expectedErrMsg) > 0 { + if err == nil { + t.Errorf("unexpected success in DetectAuthMethods test, iteration %d. expected error: %s\n", i, tC.expectedErrMsg) + } + + assert.Contains(t, err.Error(), tC.expectedErrMsg, "wrong error message") + } else { + if err != nil { + t.Errorf("unexpected error in DetectAuthMethods test, iteration %d: %v\n", i, err) } for _, stream := range tC.expectedStreams {