(fix) Make bias statistics complete for all elements #4496
base: devel
@@ -82,6 +82,47 @@
                sys_stat[key] = torch.cat(sys_stat[key], dim=0)
        dict_to_device(sys_stat)
        lst.append(sys_stat)

    all_elements = set()
    if datasets and hasattr(datasets[0], "element_to_frames"):
        all_elements.update(datasets[0].element_to_frames.keys())
    print("we want", all_elements)

    collected_elements = set()
    for sys_stat in lst:
        if "atype" in sys_stat:
            collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
            missing_elements = all_elements - collected_elements

            for missing_element in missing_elements:
                for i, dataset in enumerate(datasets):
                    if hasattr(dataset, "element_to_frames"):
                        frame_indices = dataset.element_to_frames.get(
                            missing_element, []
                        )
                        for frame_idx in frame_indices:
                            if len(lst[i]["atype"]) >= nbatches:
                                break
                            frame_data = dataset[frame_idx]
                            for key in frame_data:
                                if key not in lst[i]:
                                    lst[i][key] = []
                                lst[i][key].append(frame_data[key])

            collected_elements = set()
            for sys_stat in lst:
                if "atype" in sys_stat:
                    collected_elements.update(
                        np.unique(sys_stat["atype"].cpu().numpy())
                    )

Comment on lines +85 to +118

💡 Codebase verification

Based on the code analysis, I can now provide a final response:

Add validation check for missing elements after collection

After the final collection of elements, add a validation check to ensure all required elements were found:

    # After the final collected_elements update
    if all_elements - collected_elements:
        missing = all_elements - collected_elements
        warnings.warn(f"Could not find frames containing elements: {missing}")

🔗 Analysis chain

Consider adding validation for collected elements. After collecting elements, it would be good to verify that all required elements were found and log a warning if any are still missing.

Let me analyze the code further to understand the element collection and validation context.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

    #!/bin/bash
    # Description: Check if there are any elements that might be consistently missing across datasets
    # Search for element type definitions and usages
    rg -A 5 "atype|element_to_frames"

Length of output: 84908

Script:

    #!/bin/bash
    # Search for element validation or error handling related to missing elements
    rg -A 5 "missing_elements|element_to_frames|all_elements"

Length of output: 2630

Check notice
Code scanning / CodeQL
Nested loops with same variable Note
Nested for statement uses loop variable 'sys_stat' of enclosing for statement
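
The validation check suggested in the review can be sketched as a small standalone helper. This is only an illustration under assumptions: `warn_if_elements_missing` is a hypothetical name that does not appear in the PR, and it takes plain Python sets rather than the tensors used in the diff.

```python
import warnings


def warn_if_elements_missing(all_elements, collected_elements):
    """Warn when some required element types were never collected.

    Hypothetical helper for illustration only (not part of the PR).
    Returns the set of element types that are still missing.
    """
    missing = set(all_elements) - set(collected_elements)
    if missing:
        # Sort for a stable, readable message.
        warnings.warn(
            f"Could not find frames containing elements: {sorted(missing)}"
        )
    return missing
```

Returning the missing set as well as warning lets a caller decide whether to continue with incomplete statistics or to stop.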

    for sys_stat in lst:
        for key in sys_stat:
            if isinstance(sys_stat[key], list) and isinstance(
                sys_stat[key][0], torch.Tensor
            ):
                sys_stat[key] = torch.cat(sys_stat[key], dim=0)

    return lst

Check warning
Code scanning / CodeQL
Unreachable code Warning
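
For readers following the diff, the overall idea (find element types absent from the sampled statistics, then top up each system with frames known to contain them) can be sketched in plain Python. This is a simplified illustration, not the PR's implementation. `collected_types`, `top_up_missing_elements`, `sampled`, and `max_frames` are hypothetical names, frames are plain dicts with an `atype` list instead of tensors, and the per-system `element_to_frames` mappings are passed in as a list rather than read from dataset attributes. Unlike the diff, which reads the element list from `datasets[0]` only, the sketch takes the union over all systems. Each loop uses its own variable name, which avoids the shadowing flagged by CodeQL.

```python
def collected_types(sampled):
    """Return the set of element types present in the sampled frames."""
    return {t for frames in sampled for frame in frames for t in frame["atype"]}


def top_up_missing_elements(datasets, element_to_frames, sampled, max_frames):
    """Append frames containing missing element types to each system's sample.

    datasets[i]          : list of frames (dicts with an "atype" list)
    element_to_frames[i] : maps element type -> frame indices in datasets[i]
    sampled[i]           : frames already picked for system i (modified in place)
    max_frames           : cap on the number of frames kept per system

    Returns the element types that are still missing afterwards.
    """
    all_elements = set()
    for mapping in element_to_frames:
        all_elements.update(mapping)

    missing = all_elements - collected_types(sampled)
    for element in sorted(missing):
        for i, mapping in enumerate(element_to_frames):
            for frame_idx in mapping.get(element, []):
                if len(sampled[i]) >= max_frames:
                    break  # respect the per-system cap
                frame = datasets[i][frame_idx]
                if frame not in sampled[i]:
                    sampled[i].append(frame)

    return all_elements - collected_types(sampled)
```

If the returned set is non-empty, the cap was reached before every element type could be covered, which is the situation the review comment suggests warning about.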