diff --git a/.grit/patterns/python/_test_correct_source.md b/.grit/patterns/python/_test_correct_source.md new file mode 100644 index 00000000..28c46423 --- /dev/null +++ b/.grit/patterns/python/_test_correct_source.md @@ -0,0 +1,28 @@ +Test for ensuring new imports go to the right place, from [this issue](https://github.com/getgrit/gritql/issues/449). + +```grit +engine marzano(0.1) +language python + +`$x = 1` where { + add_import(source="typing_extensions", name="Self"), + add_import(source="pydantic", name="model_validator"), +} +``` + +Input: + +```python +import math +x = 1 +``` + +Expected output: + +```python +import math +from typing_extensions import Self +from pydantic import model_validator + +x = 1 +``` diff --git a/.grit/patterns/python/_test_two_imports.md b/.grit/patterns/python/_test_two_imports.md new file mode 100644 index 00000000..bd3a57a8 --- /dev/null +++ b/.grit/patterns/python/_test_two_imports.md @@ -0,0 +1,27 @@ +Test for adding multiple imports to the same library, for [this issue](https://github.com/getgrit/gritql/issues/450). + +```grit +engine marzano(0.1) +language python + +`x = 1` as $SELF where { + add_import(source="pydantic", name="Self"), + add_import(source="pydantic", name="pydantic1"), +} +``` + +Input: + +```python +from pydantic import BaseModel, Extra, Field, root_validator + +x = 1 +``` + +Expected Output: + +```python +from pydantic import BaseModel, Extra, Field, root_validator, Self, pydantic1 + +x = 1 +``` diff --git a/.grit/patterns/python/py_imports.grit b/.grit/patterns/python/py_imports.grit index d82252f0..b07b857f 100644 --- a/.grit/patterns/python/py_imports.grit +++ b/.grit/patterns/python/py_imports.grit @@ -9,8 +9,7 @@ pattern import_from($source, $names) { pattern before_each_file_prep_imports() { $_ where { $GLOBAL_NEW_BARE_IMPORT_NAMES = [], - $GLOBAL_NEW_FROM_IMPORT_SOURCES = [], - $GLOBAL_NEW_FROM_IMPORT_NAMES = [], + $GLOBAL_NEW_FROM_IMPORTS = [], $GLOBAL_IMPORTS_TO_REMOVE = [], } } @@ -22,26 +21,7 @@ pattern after_each_file_handle_imports() { } } -pattern process_one_source($p, $all_imports) { - [$p, $source] where { - $new_names = [], - $GLOBAL_NEW_FROM_IMPORT_NAMES <: some bubble($new_names) [$p, $name, $source] where { - $new_names += $name, - }, - if ($p <: module(statements = some import_from($source, $names))) { - $names_to_add = "", - $new_names <: some $name where { - if (!$names <: some $name) { - $names_to_add += `, $name` - } - }, - $names => `$names$names_to_add`, - } else { - $joined_names = join(list = $new_names, separator = ", "), - $all_imports += `from $source import $joined_names\n`, - } - } -} + // TODO: remove imports from the global list as we remove them pattern handle_one_removal_candidate() { @@ -93,10 +73,39 @@ predicate remove_import($source) { } } +pattern process_one_source($all_imports) { + $source where { + $maybe_new_names = [], + $GLOBAL_NEW_FROM_IMPORTS <: some bubble($maybe_new_names, $source) [$candidate_source, $name] where { + if ($source <: $candidate_source) { + $maybe_new_names += $name + } + }, + $new_names = distinct(list=$maybe_new_names), + if ($program <: module(statements = some import_from($source, $names))) { + $names_to_add = "", + $new_names <: some bubble($names_to_add, $names) $new_name where { + $names <: not some $new_name, + $names_to_add += `, $new_name` + }, + $names => `$names$names_to_add`, + } else { + $joined_names = join(list = $new_names, separator = ", "), + $all_imports += `from $source import $joined_names\n`, + } + } +} + pattern insert_imports() { $body where { $all_imports = "", - $GLOBAL_NEW_FROM_IMPORT_SOURCES <: maybe some process_one_source($p, $all_imports), + $sources = [], + // First make sure we have an enty for each import + $GLOBAL_NEW_FROM_IMPORTS <: maybe some bubble($sources) [$source, $name] where { + $sources += $source + }, + $unique_sources = distinct(list=$sources), + $unique_sources <: maybe some process_one_source($all_imports), $GLOBAL_NEW_BARE_IMPORT_NAMES <: maybe some bubble($all_imports) $name where { $all_imports += `import $name\n`, }, @@ -131,12 +140,7 @@ pattern imported_from($source) { pattern ensure_import_from($source) { $name where { if ($name <: not imported_from($source)) { - if ($GLOBAL_NEW_FROM_IMPORT_SOURCES <: not some [$program, $source]) { - $GLOBAL_NEW_FROM_IMPORT_SOURCES += [$program, $source] - }, - if ($GLOBAL_NEW_FROM_IMPORT_NAMES <: not some [$program, $name, $source]) { - $GLOBAL_NEW_FROM_IMPORT_NAMES += [$program, $name, $source] - } + $GLOBAL_NEW_FROM_IMPORTS += [$source, $name] } } }