From 45cf424f1f2b7112dea8a3e89ba608150656dd1c Mon Sep 17 00:00:00 2001 From: Mathias Gibbens Date: Thu, 24 Oct 2024 10:46:36 -0600 Subject: [PATCH] windows: Refactor code for injecting drivers into a library for easier shared use Signed-off-by: Mathias Gibbens (cherry picked from commit 8b651a1235f7c3df615e9cecb137a2187fd9629f) Signed-off-by: Din Music License: Apache-2.0 --- lxd-imagebuilder/main_repack-windows.go | 253 +-------------------- windows/repack_util.go | 280 ++++++++++++++++++++++++ 2 files changed, 288 insertions(+), 245 deletions(-) create mode 100644 windows/repack_util.go diff --git a/lxd-imagebuilder/main_repack-windows.go b/lxd-imagebuilder/main_repack-windows.go index 35f9aaf..83ed5bf 100644 --- a/lxd-imagebuilder/main_repack-windows.go +++ b/lxd-imagebuilder/main_repack-windows.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "context" "encoding/hex" "errors" @@ -11,7 +10,6 @@ import ( "os/exec" "path/filepath" "slices" - "strconv" "strings" lxdShared "github.com/canonical/lxd/shared" @@ -257,12 +255,14 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str return fmt.Errorf("Unable to find install.wim: %w", err) } - bootWimInfo, err := c.getWimInfo(bootWim) + repackUtil := windows.NewRepackUtil(c.global.ctx, c.global.flagCacheDir, c.global.logger) + + bootWimInfo, err := repackUtil.GetWimInfo(bootWim) if err != nil { return fmt.Errorf("Failed to get boot wim info: %w", err) } - installWimInfo, err := c.getWimInfo(installWim) + installWimInfo, err := repackUtil.GetWimInfo(installWim) if err != nil { return fmt.Errorf("Failed to get install wim info: %w", err) } @@ -283,14 +283,16 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str return errors.New("Failed to detect Windows architecture. Please provide the architecture using the --windows-arch flag") } + repackUtil.SetWindowsVersionArchitecture(c.flagWindowsVersion, c.flagWindowsArchitecture) + // This injects the drivers into the installation process - err = c.modifyWim(bootWim, bootWimInfo) + err = repackUtil.InjectDriversIntoWim(bootWim, bootWimInfo, filepath.Join(c.global.flagCacheDir, "drivers")) if err != nil { return fmt.Errorf("Failed to modify wim %q: %w", filepath.Base(bootWim), err) } // This injects the drivers into the final OS - err = c.modifyWim(installWim, installWimInfo) + err = repackUtil.InjectDriversIntoWim(installWim, installWimInfo, filepath.Join(c.global.flagCacheDir, "drivers")) if err != nil { return fmt.Errorf("Failed to modify wim %q: %w", filepath.Base(installWim), err) } @@ -339,86 +341,6 @@ func (c *cmdRepackWindows) run(cmd *cobra.Command, args []string, overlayDir str return nil } -func (c *cmdRepackWindows) getWimInfo(wimFile string) (info windows.WimInfo, err error) { - wimName := filepath.Base(wimFile) - var buf bytes.Buffer - err = shared.RunCommand(c.global.ctx, nil, &buf, "wimlib-imagex", "info", wimFile) - if err != nil { - err = fmt.Errorf("Failed to retrieve wim %q information: %w", wimName, err) - return - } - - info, err = windows.ParseWimInfo(&buf) - if err != nil { - err = fmt.Errorf("Failed to parse wim info %s: %w", wimFile, err) - return - } - - return -} - -func (c *cmdRepackWindows) modifyWim(wimFile string, info windows.WimInfo) (err error) { - wimName := filepath.Base(wimFile) - // Injects the drivers - for idx := 1; idx <= info.ImageCount(); idx++ { - name := info.Name(idx) - err = c.modifyWimIndex(wimFile, idx, name) - if err != nil { - return fmt.Errorf("Failed to modify index %d=%s of %q: %w", idx, name, wimName, err) - } - } - return -} - -func (c *cmdRepackWindows) modifyWimIndex(wimFile string, index int, name string) error { - wimIndex := strconv.Itoa(index) - wimPath := filepath.Join(c.global.flagCacheDir, "wim", wimIndex) - wimName := filepath.Base(wimFile) - - logger := c.global.logger.WithFields(logrus.Fields{"wim": strings.TrimSuffix(wimName, ".wim"), "idx": wimIndex + ":" + name}) - if !lxdShared.PathExists(wimPath) { - err := os.MkdirAll(wimPath, 0755) - if err != nil { - return fmt.Errorf("Failed to create directory %q: %w", wimPath, err) - } - } - - success := false - logger.Info("Mounting") - // Mount wim file - err := shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "mountrw", wimFile, wimIndex, wimPath, "--allow-other") - if err != nil { - return fmt.Errorf("Failed to mount %q: %w", wimName, err) - } - - defer func() { - if !success { - _ = shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath) - } - }() - - dirs, err := c.getWindowsDirectories(wimPath) - if err != nil { - return fmt.Errorf("Failed to get required windows directories: %w", err) - } - - logger.Info("Modifying") - // Create registry entries and copy files - err = c.injectDrivers(dirs["inf"], dirs["drivers"], dirs["filerepository"], dirs["config"]) - if err != nil { - return fmt.Errorf("Failed to inject drivers: %w", err) - } - - logger.Info("Unmounting") - err = shared.RunCommand(c.global.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath, "--commit") - if err != nil { - return fmt.Errorf("Failed to unmount WIM image %q: %w", wimName, err) - } - - success = true - return nil -} - func (c *cmdRepackWindows) checkDependencies() error { dependencies := []string{"genisoimage", "hivexregedit", "rsync", "wimlib-imagex"} @@ -432,165 +354,6 @@ func (c *cmdRepackWindows) checkDependencies() error { return nil } -func (c *cmdRepackWindows) getWindowsDirectories(wimPath string) (dirs map[string]string, err error) { - dirs = map[string]string{} - dirs["inf"], err = shared.FindFirstMatch(wimPath, "windows", "inf") - if err != nil { - return nil, fmt.Errorf("Failed to determine windows/inf path: %w", err) - } - - dirs["config"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "config") - if err != nil { - return nil, fmt.Errorf("Failed to determine windows/system32/config path: %w", err) - } - - dirs["drivers"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "drivers") - if err != nil { - return nil, fmt.Errorf("Failed to determine windows/system32/drivers path: %w", err) - } - - dirs["filerepository"], err = shared.FindFirstMatch(wimPath, "windows", "system32", "driverstore", "filerepository") - if err != nil { - return nil, fmt.Errorf("Failed to determine windows/system32/driverstore/filerepository path: %w", err) - } - - return -} - -func (c *cmdRepackWindows) injectDrivers(infDir, driversDir, filerepositoryDir, configDir string) error { - logger := c.global.logger - - driverPath := filepath.Join(c.global.flagCacheDir, "drivers") - i := 0 - - driversRegistry := "Windows Registry Editor Version 5.00" - systemRegistry := "Windows Registry Editor Version 5.00" - softwareRegistry := "Windows Registry Editor Version 5.00" - - for driverName, driverInfo := range windows.Drivers { - logger.WithField("driver", driverName).Debug("Injecting driver") - - infFilename := fmt.Sprintf("oem%d.inf", i) - sourceDir := filepath.Join(driverPath, driverName, c.flagWindowsVersion, c.flagWindowsArchitecture) - targetBaseDir := filepath.Join(filerepositoryDir, driverInfo.PackageName) - if !lxdShared.PathExists(targetBaseDir) { - err := os.MkdirAll(targetBaseDir, 0755) - if err != nil { - return err - } - } - - for ext, dir := range map[string]string{"inf": infDir, "cat": driversDir, "dll": driversDir, "sys": driversDir} { - sourceMatches, err := shared.FindAllMatches(sourceDir, fmt.Sprintf("*.%s", ext)) - if err != nil { - logger.Debugf("failed to find first match %q %q", driverName, ext) - continue - } - - for _, sourcePath := range sourceMatches { - targetName := filepath.Base(sourcePath) - targetPath := filepath.Join(targetBaseDir, targetName) - if err = shared.Copy(sourcePath, targetPath); err != nil { - return err - } - - if ext == "cat" { - continue - } else if ext == "inf" { - targetName = infFilename - } - - targetPath = filepath.Join(dir, targetName) - if err = shared.Copy(sourcePath, targetPath); err != nil { - return err - } - } - } - - classGUID, err := windows.ParseDriverClassGUID(driverName, filepath.Join(infDir, infFilename)) - if err != nil { - return err - } - - ctx := pongo2.Context{ - "infFile": infFilename, - "packageName": driverInfo.PackageName, - "driverName": driverName, - "classGuid": classGUID, - } - - // Update Windows DRIVERS registry - if driverInfo.DriversRegistry != "" { - tpl, err := pongo2.FromString(driverInfo.DriversRegistry) - if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) - } - - out, err := tpl.Execute(ctx) - if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) - } - - driversRegistry = fmt.Sprintf("%s\n\n%s", driversRegistry, out) - } - - // Update Windows SYSTEM registry - if driverInfo.SystemRegistry != "" { - tpl, err := pongo2.FromString(driverInfo.SystemRegistry) - if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) - } - - out, err := tpl.Execute(ctx) - if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) - } - - systemRegistry = fmt.Sprintf("%s\n\n%s", systemRegistry, out) - } - - // Update Windows SOFTWARE registry - if driverInfo.SoftwareRegistry != "" { - tpl, err := pongo2.FromString(driverInfo.SoftwareRegistry) - if err != nil { - return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) - } - - out, err := tpl.Execute(ctx) - if err != nil { - return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) - } - - softwareRegistry = fmt.Sprintf("%s\n\n%s", softwareRegistry, out) - } - - i++ - } - - logger.WithField("hivefile", "DRIVERS").Debug("Updating Windows registry") - - err := shared.RunCommand(c.global.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(configDir, "DRIVERS")) - if err != nil { - return fmt.Errorf("Failed to edit Windows DRIVERS registry: %w", err) - } - - logger.WithField("hivefile", "SYSTEM").Debug("Updating Windows registry") - - err = shared.RunCommand(c.global.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(configDir, "SYSTEM")) - if err != nil { - return fmt.Errorf("Failed to edit Windows SYSTEM registry: %w", err) - } - - logger.WithField("hivefile", "SOFTWARE").Debug("Updating Windows registry") - - err = shared.RunCommand(c.global.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(configDir, "SOFTWARE")) - if err != nil { - return fmt.Errorf("Failed to edit Windows SOFTWARE registry: %w", err) - } - - return nil -} - // toHex is a pongo2 filter which converts the provided value to a hex value understood by the Windows registry. func toHex(in *pongo2.Value, param *pongo2.Value) (out *pongo2.Value, err *pongo2.Error) { dst := make([]byte, hex.EncodedLen(len(in.String()))) diff --git a/windows/repack_util.go b/windows/repack_util.go new file mode 100644 index 0000000..dea53e7 --- /dev/null +++ b/windows/repack_util.go @@ -0,0 +1,280 @@ +package windows + +import ( + "bytes" + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + + lxdShared "github.com/canonical/lxd/shared" + "github.com/flosch/pongo2/v4" + "github.com/sirupsen/logrus" + + "github.com/canonical/lxd-imagebuilder/shared" +) + +// RepackUtil is a helper struct for repacking Windows images. +type RepackUtil struct { + ctx context.Context + logger *logrus.Logger + cacheDir string + windowsVersion string + windowsArchitecture string +} + +// NewRepackUtil returns a new RepackUtil object. +func NewRepackUtil(ctx context.Context, cacheDir string, logger *logrus.Logger) RepackUtil { + return RepackUtil{ + ctx: ctx, + logger: logger, + cacheDir: cacheDir, + } +} + +// SetWindowsVersionArchitecture is a helper function for setting the specific Windows version and architecture. +func (r *RepackUtil) SetWindowsVersionArchitecture(windowsVersion string, windowsArchitecture string) { + r.windowsVersion = windowsVersion + r.windowsArchitecture = windowsArchitecture +} + +// GetWimInfo returns information about the specified wim file. +func (r *RepackUtil) GetWimInfo(wimFile string) (info WimInfo, err error) { + wimName := filepath.Base(wimFile) + var buf bytes.Buffer + err = shared.RunCommand(r.ctx, nil, &buf, "wimlib-imagex", "info", wimFile) + if err != nil { + err = fmt.Errorf("Failed to retrieve wim %q information: %w", wimName, err) + return + } + + info, err = ParseWimInfo(&buf) + if err != nil { + err = fmt.Errorf("Failed to parse wim info %s: %w", wimFile, err) + return + } + + return +} + +// InjectDriversIntoWim will inject drivers into the specified wim file. +func (r *RepackUtil) InjectDriversIntoWim(wimFile string, info WimInfo, driverPath string) (err error) { + wimName := filepath.Base(wimFile) + // Injects the drivers + for idx := 1; idx <= info.ImageCount(); idx++ { + name := info.Name(idx) + err = r.modifyWimIndex(wimFile, idx, name, driverPath) + if err != nil { + return fmt.Errorf("Failed to modify index %d=%s of %q: %w", idx, name, wimName, err) + } + } + return +} + +// InjectDrivers injects drivers from driverPath into the windowsRootPath. +func (r *RepackUtil) InjectDrivers(windowsRootPath string, driverPath string) error { + dirs, err := r.getWindowsDirectories(windowsRootPath) + if err != nil { + return fmt.Errorf("Failed to get required Windows directories under path '%s': %w", windowsRootPath, err) + } + + logger := r.logger + + i := 0 + + driversRegistry := "Windows Registry Editor Version 5.00" + systemRegistry := "Windows Registry Editor Version 5.00" + softwareRegistry := "Windows Registry Editor Version 5.00" + for driverName, driverInfo := range Drivers { + logger.WithField("driver", driverName).Debug("Injecting driver") + infFilename := fmt.Sprintf("oem-virtio-lxd%d.inf", i) + sourceDir := filepath.Join(driverPath, driverName, r.windowsVersion, r.windowsArchitecture) + targetBaseDir := filepath.Join(dirs["filerepository"], driverInfo.PackageName) + if !lxdShared.PathExists(targetBaseDir) { + err := os.MkdirAll(targetBaseDir, 0755) + if err != nil { + logger.Error(err) + return err + } + } + + for ext, dir := range map[string]string{"inf": dirs["inf"], "cat": dirs["drivers"], "dll": dirs["drivers"], "exe": dirs["drivers"], "sys": dirs["drivers"]} { + sourceMatches, err := shared.FindAllMatches(sourceDir, fmt.Sprintf("*.%s", ext)) + if err != nil { + logger.Debugf("failed to find first match %q %q", driverName, ext) + continue + } + + for _, sourcePath := range sourceMatches { + targetName := filepath.Base(sourcePath) + targetPath := filepath.Join(targetBaseDir, targetName) + if err = shared.Copy(sourcePath, targetPath); err != nil { + return err + } + + if ext == "cat" || ext == "exe" { + continue + } else if ext == "inf" { + targetName = infFilename + } + + targetPath = filepath.Join(dir, targetName) + if err = shared.Copy(sourcePath, targetPath); err != nil { + return err + } + } + } + + classGUID, err := ParseDriverClassGUID(driverName, filepath.Join(dirs["inf"], infFilename)) + if err != nil { + return err + } + + ctx := pongo2.Context{ + "infFile": infFilename, + "packageName": driverInfo.PackageName, + "driverName": driverName, + "classGuid": classGUID, + } + + // Update Windows DRIVERS registry + if driverInfo.DriversRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.DriversRegistry) + if err != nil { + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) + } + + out, err := tpl.Execute(ctx) + if err != nil { + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) + } + + driversRegistry = fmt.Sprintf("%s\n\n%s", driversRegistry, out) + } + + // Update Windows SYSTEM registry + if driverInfo.SystemRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.SystemRegistry) + if err != nil { + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) + } + + out, err := tpl.Execute(ctx) + if err != nil { + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) + } + + systemRegistry = fmt.Sprintf("%s\n\n%s", systemRegistry, out) + } + + // Update Windows SOFTWARE registry + if driverInfo.SoftwareRegistry != "" { + tpl, err := pongo2.FromString(driverInfo.SoftwareRegistry) + if err != nil { + return fmt.Errorf("Failed to parse template for driver %q: %w", driverName, err) + } + + out, err := tpl.Execute(ctx) + if err != nil { + return fmt.Errorf("Failed to render template for driver %q: %w", driverName, err) + } + + softwareRegistry = fmt.Sprintf("%s\n\n%s", softwareRegistry, out) + } + + i++ + } + + logger.WithField("hivefile", "DRIVERS").Debug("Updating Windows registry") + + err = shared.RunCommand(r.ctx, strings.NewReader(driversRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\DRIVERS'", filepath.Join(dirs["config"], "DRIVERS")) + if err != nil { + return fmt.Errorf("Failed to edit Windows DRIVERS registry: %w", err) + } + + logger.WithField("hivefile", "SYSTEM").Debug("Updating Windows registry") + + err = shared.RunCommand(r.ctx, strings.NewReader(systemRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SYSTEM'", filepath.Join(dirs["config"], "SYSTEM")) + if err != nil { + return fmt.Errorf("Failed to edit Windows SYSTEM registry: %w", err) + } + + logger.WithField("hivefile", "SOFTWARE").Debug("Updating Windows registry") + + err = shared.RunCommand(r.ctx, strings.NewReader(softwareRegistry), nil, "hivexregedit", "--merge", "--prefix='HKEY_LOCAL_MACHINE\\SOFTWARE'", filepath.Join(dirs["config"], "SOFTWARE")) + if err != nil { + return fmt.Errorf("Failed to edit Windows SOFTWARE registry: %w", err) + } + + return nil +} + +func (r *RepackUtil) getWindowsDirectories(rootPath string) (dirs map[string]string, err error) { + dirs = map[string]string{} + dirs["inf"], err = shared.FindFirstMatch(rootPath, "windows", "inf") + if err != nil { + return nil, fmt.Errorf("Failed to determine windows/inf path: %w", err) + } + + dirs["config"], err = shared.FindFirstMatch(rootPath, "windows", "system32", "config") + if err != nil { + return nil, fmt.Errorf("Failed to determine windows/system32/config path: %w", err) + } + + dirs["drivers"], err = shared.FindFirstMatch(rootPath, "windows", "system32", "drivers") + if err != nil { + return nil, fmt.Errorf("Failed to determine windows/system32/drivers path: %w", err) + } + + dirs["filerepository"], err = shared.FindFirstMatch(rootPath, "windows", "system32", "driverstore", "filerepository") + if err != nil { + return nil, fmt.Errorf("Failed to determine windows/system32/driverstore/filerepository path: %w", err) + } + + return +} + +func (r *RepackUtil) modifyWimIndex(wimFile string, index int, name string, driverPath string) error { + wimIndex := strconv.Itoa(index) + wimPath := filepath.Join(r.cacheDir, "wim", wimIndex) + wimName := filepath.Base(wimFile) + logger := r.logger.WithFields(logrus.Fields{"wim": wimName, "idx": wimIndex + ":" + name}) + if !lxdShared.PathExists(wimPath) { + err := os.MkdirAll(wimPath, 0755) + if err != nil { + return fmt.Errorf("Failed to create directory %q: %w", wimPath, err) + } + } + + success := false + logger.Info("Mounting") + // Mount wim file + err := shared.RunCommand(r.ctx, nil, nil, "wimlib-imagex", "mountrw", wimFile, wimIndex, wimPath, "--allow-other") + if err != nil { + return fmt.Errorf("Failed to mount %q: %w", wimName, err) + } + + defer func() { + if !success { + _ = shared.RunCommand(r.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath) + } + }() + + logger.Info("Modifying") + // Create registry entries and copy files + err = r.InjectDrivers(wimPath, driverPath) + if err != nil { + return fmt.Errorf("Failed to inject drivers: %w", err) + } + + logger.Info("Unmounting") + err = shared.RunCommand(r.ctx, nil, nil, "wimlib-imagex", "unmount", wimPath, "--commit") + if err != nil { + return fmt.Errorf("Failed to unmount WIM image %q: %w", wimName, err) + } + + success = true + return nil +}