Skip to content

Commit

Permalink
fix: repair some bugs with add_import in Python (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored Aug 8, 2024
1 parent e069cde commit 41f89f7
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 29 deletions.
28 changes: 28 additions & 0 deletions .grit/patterns/python/_test_correct_source.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Test for ensuring new imports go to the right place, from [this issue](https://github.com/getgrit/gritql/issues/449).

```grit
engine marzano(0.1)
language python
`$x = 1` where {
add_import(source="typing_extensions", name="Self"),
add_import(source="pydantic", name="model_validator"),
}
```

Input:

```python
import math
x = 1
```

Expected output:

```python
import math
from typing_extensions import Self
from pydantic import model_validator

x = 1
```
27 changes: 27 additions & 0 deletions .grit/patterns/python/_test_two_imports.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Test for adding multiple imports to the same library, for [this issue](https://github.com/getgrit/gritql/issues/450).

```grit
engine marzano(0.1)
language python
`x = 1` as $SELF where {
add_import(source="pydantic", name="Self"),
add_import(source="pydantic", name="pydantic1"),
}
```

Input:

```python
from pydantic import BaseModel, Extra, Field, root_validator

x = 1
```

Expected Output:

```python
from pydantic import BaseModel, Extra, Field, root_validator, Self, pydantic1

x = 1
```
62 changes: 33 additions & 29 deletions .grit/patterns/python/py_imports.grit
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ pattern import_from($source, $names) {
pattern before_each_file_prep_imports() {
$_ where {
$GLOBAL_NEW_BARE_IMPORT_NAMES = [],
$GLOBAL_NEW_FROM_IMPORT_SOURCES = [],
$GLOBAL_NEW_FROM_IMPORT_NAMES = [],
$GLOBAL_NEW_FROM_IMPORTS = [],
$GLOBAL_IMPORTS_TO_REMOVE = [],
}
}
Expand All @@ -22,26 +21,7 @@ pattern after_each_file_handle_imports() {
}
}

pattern process_one_source($p, $all_imports) {
[$p, $source] where {
$new_names = [],
$GLOBAL_NEW_FROM_IMPORT_NAMES <: some bubble($new_names) [$p, $name, $source] where {
$new_names += $name,
},
if ($p <: module(statements = some import_from($source, $names))) {
$names_to_add = "",
$new_names <: some $name where {
if (!$names <: some $name) {
$names_to_add += `, $name`
}
},
$names => `$names$names_to_add`,
} else {
$joined_names = join(list = $new_names, separator = ", "),
$all_imports += `from $source import $joined_names\n`,
}
}
}


// TODO: remove imports from the global list as we remove them
pattern handle_one_removal_candidate() {
Expand Down Expand Up @@ -93,10 +73,39 @@ predicate remove_import($source) {
}
}

pattern process_one_source($all_imports) {
$source where {
$maybe_new_names = [],
$GLOBAL_NEW_FROM_IMPORTS <: some bubble($maybe_new_names, $source) [$candidate_source, $name] where {
if ($source <: $candidate_source) {
$maybe_new_names += $name
}
},
$new_names = distinct(list=$maybe_new_names),
if ($program <: module(statements = some import_from($source, $names))) {
$names_to_add = "",
$new_names <: some bubble($names_to_add, $names) $new_name where {
$names <: not some $new_name,
$names_to_add += `, $new_name`
},
$names => `$names$names_to_add`,
} else {
$joined_names = join(list = $new_names, separator = ", "),
$all_imports += `from $source import $joined_names\n`,
}
}
}

pattern insert_imports() {
$body where {
$all_imports = "",
$GLOBAL_NEW_FROM_IMPORT_SOURCES <: maybe some process_one_source($p, $all_imports),
$sources = [],
// First make sure we have an enty for each import
$GLOBAL_NEW_FROM_IMPORTS <: maybe some bubble($sources) [$source, $name] where {
$sources += $source
},
$unique_sources = distinct(list=$sources),
$unique_sources <: maybe some process_one_source($all_imports),
$GLOBAL_NEW_BARE_IMPORT_NAMES <: maybe some bubble($all_imports) $name where {
$all_imports += `import $name\n`,
},
Expand Down Expand Up @@ -131,12 +140,7 @@ pattern imported_from($source) {
pattern ensure_import_from($source) {
$name where {
if ($name <: not imported_from($source)) {
if ($GLOBAL_NEW_FROM_IMPORT_SOURCES <: not some [$program, $source]) {
$GLOBAL_NEW_FROM_IMPORT_SOURCES += [$program, $source]
},
if ($GLOBAL_NEW_FROM_IMPORT_NAMES <: not some [$program, $name, $source]) {
$GLOBAL_NEW_FROM_IMPORT_NAMES += [$program, $name, $source]
}
$GLOBAL_NEW_FROM_IMPORTS += [$source, $name]
}
}
}
Expand Down

0 comments on commit 41f89f7

Please sign in to comment.