Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implemented appendContent to persist a file handle and use Stream.IO.StreamWriter for writing. #36

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions winrmcp/cp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"log"
"os"
"strings"
"sync"

"github.com/masterzen/winrm"
Expand Down Expand Up @@ -87,6 +88,17 @@ func uploadChunks(client *winrm.Client, filePath string, maxChunks int, reader i
maxChunks = 1
}

// Construct our output and error channels. We're only responsible for
// closing the output channel when we're done with the file or we receive an
// error.
outCh, errCh, err := writeContentChannel(shell, filePath, false)
if err != nil {
return false, fmt.Errorf("Unable to write to file with shell: %w", err)
}
defer close(outCh)

// Iterate through all of the chunks whilst writing our encoded content
// chunk-by-chunk into the output channel that we constructed.
for i := 0; i < maxChunks; i++ {
n, err := reader.Read(chunk)

Expand All @@ -97,12 +109,17 @@ func uploadChunks(client *winrm.Client, filePath string, maxChunks int, reader i
return true, nil
}

// Encode our content to a string, and then write it into the output chan
content := base64.StdEncoding.EncodeToString(chunk[:n])
if err = appendContent(shell, filePath, content); err != nil {
select {
case outCh <- content:
continue

// If we got an error, then we need to just leave with a failure.
case err := <-errCh:
return false, err
}
}

return false, nil
}

Expand Down Expand Up @@ -212,9 +229,8 @@ func cleanupContent(client *winrm.Client, filePath string) error {
return nil
}

func appendContent(shell *winrm.Shell, filePath, content string) error {
cmd, err := shell.Execute(fmt.Sprintf("echo %s >> \"%s\"", content, filePath))

func executeCommandSync(shell *winrm.Shell, command string) error {
cmd, err := shell.Execute(command)
if err != nil {
return err
}
Expand All @@ -236,10 +252,78 @@ func appendContent(shell *winrm.Shell, filePath, content string) error {
if cmd.ExitCode() != 0 {
return fmt.Errorf("upload operation returned code=%d", cmd.ExitCode())
}

return nil
}

// Caller is responsible for closing the input channel, we're responsible for
// closing the error channel.
func writeContentChannel(shell *winrm.Shell, filepath string, appendTo bool) (chan string, chan error, error) {
var streamVariable string

// Generate some new temporary variables to assign our filestream, and a
// StreamWriter for it.
if name, err := tempVariable("stream_"); err != nil {
return nil, err
} else {
streamVariable = name
}

// Now we can construct a new System.IO.StreamWriter for our file, and assign it
// into its temporary variable.
if err := executeCommandSync(shell, fmt.Sprintf("$%s = New-Object -TypeName System.IO.StreamWriter %#v, $%v, %s", streamVariable, filePath, appendTo, "[System.Text.Encoding]::UTF8")); err != nil {
return nil, err
}

// Here's the channel that gets written into, and our error chan for keeping
// track of an error
input := make(chan string)
errch := make(chan error)

// Here's the goro that does the writing. We simply keep consuming strings
// from our input channel until it gets closed. Each string then gets
// written to the StreamWriter that we constructed prior.
go func() {
for content := range input {
if err := executeCommandSync(shell, fmt.Sprintf("$%s.Write(%#v)", content)); err != nil {
errch <- fmt.Errorf("Error writing to stream for temporary file %s: %w", filePath, err)
break
}
}

// Start closing our handles, and then releasing them back to the
// framework.
if err := executeCommandSync(shell, fmt.Sprintf("$%s.Close()", streamVariable)); err != nil {
log.Printf("Error closing stream for temporary file %s: %v", filePath, err)
}

if err := executeCommandSync(shell, fmt.Sprintf("$%s.Dispose($true)", streamVariable)); err != nil {
log.Printf("Error releasing stream for temporary file %s: %v", filePath, err)
}

// Now we can remove the variable we were using
if err := executeCommandSync(fmt.Sprintf("Remove-Variable -Name %s", streamVariable)); err != nil {
log.Printf("Error removing variable (%v) for temporary file %s: %w", streamVariable, filePath, err)
}

close(errch)
}()

return input, nil
}

// Generate a new variable name using a uuid
func tempVariable(prefix string) (string, error) {
uniquePart, err := uuid.NewV4()
if err != nil {
return "", err
}

// Remove all invalid characters from the uuid, and append it to the
// prefix that was given to us by the caller
variableSuffix := strings.ReplaceAll(uniquePart.String(), "-", "")
return prefix + variableSuffix, nil
}

func tempFileName() (string, error) {
uniquePart, err := uuid.NewV4()
if err != nil {
Expand Down