diff --git a/lib/iommu/iommu.go b/lib/iommu/iommu.go index 43653bc..6c7c122 100644 --- a/lib/iommu/iommu.go +++ b/lib/iommu/iommu.go @@ -85,6 +85,7 @@ func (i *IOMMU) Read() { // Add the group to the IOMMU struct i.AddGroup(grp) + } else { // Add the device to the existing group ID i.Groups[group_id].AddDevice(device) @@ -136,6 +137,9 @@ func MatchSubclass(searchval string, pArg *params.Params) []string { // Get all IOMMU devices alldevs := NewIOMMU() + // Make a regex to find devices that will need special exceptions for relative search + specialRelativeExceptions := regexp.MustCompile(`^(SATA|USB) controller`) + // Iterate through the groups for id := 0; id < len(alldevs.Groups); id++ { // For each device @@ -148,16 +152,17 @@ func MatchSubclass(searchval string, pArg *params.Params) []string { devs = append(devs, line) // If we want to search for related devices - if pArg.FlagCounter["related"] > 0 && searchval != `USB controller` { + if pArg.FlagCounter["related"] > 0 && !specialRelativeExceptions.MatchString(searchval) { // Find relatives and add them to the list related_list := findRelatedDevices(device.Vendor.ID, pArg.FlagCounter["related"], pArg) devs = append(devs, related_list...) - } else if pArg.FlagCounter["related"] > 0 && searchval == `USB controller` { + } else if pArg.FlagCounter["related"] > 0 && specialRelativeExceptions.MatchString(searchval) { // Prevent an infinite loop by passing 0 instead of related other := GetDevicesFromGroups([]int{id}, 0, pArg) devs = append(devs, other...) } + } else if pArg.Flag["rom"] && pArg.Flag["gpu"] { // If we are asked to get the path to the gpu vbios if len(pArg.IntList["iommu_group"]) > 0 { @@ -169,10 +174,12 @@ func MatchSubclass(searchval string, pArg *params.Params) []string { devs = append(devs, GetRomPath(device, pArg)...) } } + } else { // Else get the vbios path for any gpu devs = append(devs, GetRomPath(device, pArg)...) } + } else { for _, group := range pArg.IntList["iommu_group"] { if id == group { @@ -184,6 +191,7 @@ func MatchSubclass(searchval string, pArg *params.Params) []string { } else if !pArg.Flag["id"] && pArg.Flag["pciaddr"] { // If --pciaddr is supplied as an argument we display the PCI Address devs = append(devs, fmt.Sprintf("%s\n", device.Address)) + } else { // Generate the device list with the data we want line := generateDevList(id, device, pArg) @@ -191,12 +199,12 @@ func MatchSubclass(searchval string, pArg *params.Params) []string { } // If we want to search for related devices - if pArg.FlagCounter["related"] > 0 && searchval != `USB controller` { + if pArg.FlagCounter["related"] > 0 && !specialRelativeExceptions.MatchString(searchval) { // Find relatives and add them to the list related_list := findRelatedDevices(device.Vendor.ID, pArg.FlagCounter["related"], pArg) devs = append(devs, related_list...) - } else if pArg.FlagCounter["related"] > 0 && searchval == `USB controller` { + } else if pArg.FlagCounter["related"] > 0 && specialRelativeExceptions.MatchString(searchval) { // Prevent an infinite loop by passing 0 instead of related other := GetDevicesFromGroups([]int{id}, 0, pArg) devs = append(devs, other...) @@ -226,6 +234,7 @@ func GetDevicesFromGroups(groups []int, related int, pArg *params.Params) []stri // Check if the IOMMU Group exists if _, iommu_num := alldevs.Groups[group]; !iommu_num { ErrorCheck(fmt.Errorf("IOMMU Group %v does not exist", group)) + } else { // For each device in specified IOMMU group for _, device := range alldevs.Groups[group].Devices { @@ -242,6 +251,7 @@ func GetDevicesFromGroups(groups []int, related int, pArg *params.Params) []stri related_list := findRelatedDevices(device.Vendor.ID, pArg.FlagCounter["related"], pArg) output = append(output, related_list...) } + } else if !strings.Contains(device.Subclass.Name, "bridge") { if pArg.Flag["id"] && !pArg.Flag["pciaddr"] { // If --id is supplied as an argument we display the VendorID:DeviceID @@ -252,6 +262,7 @@ func GetDevicesFromGroups(groups []int, related int, pArg *params.Params) []stri related_list := findRelatedDevices(device.Vendor.ID, pArg.FlagCounter["related"], pArg) output = append(output, related_list...) } + } else if !pArg.Flag["id"] && pArg.Flag["pciaddr"] { // If --pciaddr is supplied as an argument we display the PCI Address output = append(output, fmt.Sprintf("%s\n", device.Address)) diff --git a/main.go b/main.go index fff96b4..cfab36f 100644 --- a/main.go +++ b/main.go @@ -19,6 +19,7 @@ func main() { // Print the output and exit iommu.PrintOutput(output, pArg) os.Exit(0) + } else if pArg.Flag["usb"] { // Get all USB controllers output := iommu.MatchSubclass(`USB controller`, pArg) @@ -26,6 +27,7 @@ func main() { // Print the output and exit iommu.PrintOutput(output, pArg) os.Exit(0) + } else if pArg.Flag["nic"] { // Get all Ethernet controllers output := iommu.MatchSubclass(`Ethernet controller`, pArg) @@ -37,6 +39,15 @@ func main() { // Print the output and exit iommu.PrintOutput(output, pArg) os.Exit(0) + + } else if pArg.Flag["sata"] { + // Get all Ethernet controllers + output := iommu.MatchSubclass(`SATA controller`, pArg) + + // Print the output and exit + iommu.PrintOutput(output, pArg) + os.Exit(0) + } else if len(pArg.IntList["iommu_group"]) > 0 { // Get all devices in specified IOMMU groups and append it to the output output := iommu.GetDevicesFromGroups(pArg.IntList["iommu_group"], pArg.FlagCounter["related"], pArg) @@ -44,6 +55,7 @@ func main() { // Print the output and exit iommu.PrintOutput(output, pArg) os.Exit(0) + } else { // Default behaviour mimicks the bash variant that this is based on output := iommu.GetAllDevices(pArg)