From 3e92aec426e8c1096ea4bcf1a13956da6fc12ed0 Mon Sep 17 00:00:00 2001 From: idkw Date: Wed, 21 Feb 2024 10:45:12 +0100 Subject: [PATCH] Fix unable to parse nmap output for incomplete XML output --- nmap.go | 15 +++++++++- nmap_test.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/nmap.go b/nmap.go index 5d10427..328717c 100644 --- a/nmap.go +++ b/nmap.go @@ -9,6 +9,7 @@ import ( "io" "os/exec" "strings" + "sync" "syscall" "time" @@ -125,15 +126,25 @@ func (s *Scanner) Run() (result *Run, warnings *[]string, err error) { stdoutDuplicate := io.TeeReader(stdoutPipe, &stdout) cmd.Stderr = &stderr + // According to cmd.StdoutPipe() doc, we must not "call Wait before all reads from the pipe have completed" + // We use this WaitGroup to wait for all IO operations to finish before calling wait + var wg sync.WaitGroup + var streamerErrs *errgroup.Group if s.streamer != nil { streamerErrs, _ = errgroup.WithContext(s.ctx) + wg.Add(1) streamerErrs.Go(func() error { + defer wg.Done() _, err = io.Copy(s.streamer, stdoutDuplicate) return err }) } else { - go io.Copy(io.Discard, stdoutDuplicate) + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(io.Discard, stdoutDuplicate) + }() } // Run nmap process. @@ -145,7 +156,9 @@ func (s *Scanner) Run() (result *Run, warnings *[]string, err error) { // Add goroutine that updates chan when command is finished. done := make(chan error, 1) doneProgress := make(chan bool, 1) + go func() { + wg.Wait() err := cmd.Wait() if streamerErrs != nil { streamerError := streamerErrs.Wait() diff --git a/nmap_test.go b/nmap_test.go index b93a358..02ca6c6 100644 --- a/nmap_test.go +++ b/nmap_test.go @@ -4,10 +4,13 @@ import ( "bytes" "context" "encoding/xml" + "fmt" "io/ioutil" "os" "os/exec" "reflect" + "strings" + "sync" "testing" "time" @@ -484,3 +487,77 @@ func TestCheckStdErr(t *testing.T) { }) } } + +// Test to verify the fix for a race condition works +// See: https://github.com/Ullaakut/nmap/issues/122 +func TestParseXMLOutputRaceCondition(t *testing.T) { + scans := make(chan int, 100) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + + // Publish many scan orders + wg.Add(1) + go func() { + defer wg.Done() + for taskId := 0; taskId < 1000; taskId++ { + wg.Add(1) + scans <- taskId + } + }() + + // Consume scan orders with workers in parallel + for worker := 1; worker <= 10; worker++ { + wg.Add(1) + go func(w int) { + defer wg.Done() + for { + var taskId int + + select { + case <-ctx.Done(): + t.Logf("stopping worker %d", w) + return + case i, ok := <-scans: + if !ok { + t.Logf("stopping worker %d", w) + return + } + taskId = i + default: + t.Logf("stopping worker %d", w) + return + } + + _, err := getNmapVersion(ctx) + if err != nil { + t.Errorf("[w:%d] failed scan %d with err: %s", w, taskId, err) + } else { + t.Logf("[w:%d] completed scan %d", w, taskId) + } + wg.Done() + } + }(worker) + } + + wg.Wait() +} + +// getNmapVersion returns the version of nmap installed on the system. +// e.g. "7.80". +func getNmapVersion(ctx context.Context) (string, error) { + scanner, err := NewScanner(ctx) + if err != nil { + return "", fmt.Errorf("nmap.NewScanner: %w", err) + } + + var sb strings.Builder + scanner.Streamer(&sb) + results, warnings, err := scanner.Run() + + if err != nil { + return "", fmt.Errorf("nmap.Run: %w (%v). Result: %+v", err, warnings, sb.String()) + } + return results.Version, nil +}