Skip to content

Commit

Permalink
Ultralytics Code Refactor https://ultralytics.com/actions (#5)
Browse files Browse the repository at this point in the history
* Refactor code for speed and clarity

* Auto-format by https://ultralytics.com/actions

* Update onnxslim/core/optimizer.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Auto-format by https://ultralytics.com/actions

* Update graph.py

---------

Co-authored-by: UltralyticsAssistant <[email protected]>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 20, 2024
1 parent 5fabd67 commit 22a1d5c
Show file tree
Hide file tree
Showing 14 changed files with 195 additions and 96 deletions.
36 changes: 18 additions & 18 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,32 @@ name: CI

on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]

jobs:
build:
runs-on: self-hosted

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v3

- uses: actions/setup-python@v4
with:
python-version: '3.10'
- uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: model zoo test
run: |
python -m pip install --upgrade pip wheel setuptools
pip install .
pip install pytest onnxruntime
python tests/test_folder.py --model-dir /root/data/modelzoo
- name: model zoo test
run: |
python -m pip install --upgrade pip wheel setuptools
pip install .
pip install pytest onnxruntime
python tests/test_folder.py --model-dir /root/data/modelzoo
- name: pattern matcher test
run: |
python tests/test_pattern_matcher.py
- name: pattern matcher test
run: |
python tests/test_pattern_matcher.py
- name: pattern generator test
run: |
python tests/test_pattern_generator.py
- name: pattern generator test
run: |
python tests/test_pattern_generator.py
23 changes: 11 additions & 12 deletions .github/workflows/nightly-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,22 @@ name: nightly-test

on:
schedule:
- cron: '0 18 * * *' # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day

- cron: "0 18 * * *" # Runs at 6:00 PM UTC every day, which is 2:00 AM Beijing Time the next day

jobs:
build:
runs-on: self-hosted

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v3

- uses: actions/setup-python@v4
with:
python-version: '3.10'
- uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: model test
run: |
python -m pip install --upgrade pip wheel setuptools
pip install .
pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu
python tests/test_onnx_nets.py
- name: model test
run: |
python -m pip install --upgrade pip wheel setuptools
pip install .
pip install pytest pytest-xdist onnxruntime timm torchvision --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu
python tests/test_onnx_nets.py
33 changes: 16 additions & 17 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,22 @@ permissions:

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel setuptools
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip wheel setuptools
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
38 changes: 33 additions & 5 deletions onnxslim/core/graph_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ def get_node_feeds(node):
elif isinstance(input, Constant):
feeds.append(input)
else:
for feed in input.inputs:
feeds.append(input if feed.op == "Split" else feed)
feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs)
return feeds


def get_name(name):
"""Sanitizes the input string by replacing illegal characters with underscores and prefixing with an underscore if
numeric.
"""
_illegal_char_regex = re.compile("[^0-9a-zA-Z_]+")
sanitized_name = _illegal_char_regex.sub("_", name)
if sanitized_name.isdigit():
Expand All @@ -41,16 +43,20 @@ def get_name(name):

class NodeDescriptor:
def __init__(self, node_spec):
"""Initialize NodeDescriptor with node_spec list requiring at least 4 elements."""
if not isinstance(node_spec, list):
raise ValueError("node_spec must be a list")
if len(node_spec) < 4:
raise ValueError(f"node_spec must have at least 4 elements {node_spec}")

def get_input_info(io_spec):
"""Parses io_spec to return a tuple of (integer, boolean) indicating the presence of a plus sign in the
input.
"""
if not io_spec.isdigit():
pattern_with_plus = re.search(r"(\d+)(\+)", io_spec)
if pattern_with_plus:
return int(pattern_with_plus.group(1)), True
return int(pattern_with_plus[1]), True
else:
raise ValueError(f"input_num and output_num must be integers {io_spec}")

Expand All @@ -66,46 +72,64 @@ def get_input_info(io_spec):
assert len(self.output_names) == self.output_num, f"{self.name} {len(self.output_names)} != {self.output_num}"

def __repr__(self):
"""Return a string representation of the object, including its name, operation type, input/output counts, and
input/output names.
"""
return f"name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}"

def __dict__(self):
"""Returns a dictionary representation of the object, with 'name' as the key."""
return {
"name": self,
}


class Pattern:
def __init__(self, pattern):
"""Initialize the Pattern class with a given pattern and parse its nodes."""
self.pattern = pattern
self.nodes = self.parse_nodes()

def parse_nodes(self):
"""Parse pattern into a list of NodeDescriptor objects from non-empty, stripped, and split lines."""
nodes = self.pattern.split("\n")
nodes = [line.strip().split() for line in nodes if line]
nodes = [NodeDescriptor(node) for node in nodes if node]
return nodes

def match(self, node):
"""Match a node against a precompiled pattern."""
return self.pattern.match(node)

def __repr__(self):
"""Return a string representation of the pattern attribute."""
return self.pattern


class PatternMatcher:
def __init__(self, pattern, priority):
"""Initialize the PatternMatcher with a given pattern and priority, and prepare node references and output
names.
"""
self.pattern = pattern
self.priority = priority
self.pattern_dict = {node.name: node for node in pattern.nodes}
self.output_names = [node.name for node in pattern.nodes if node.op == "output"]

def get_match_point(self):
"""Retrieve the match point node from the pattern dictionary based on output node input names."""
return self.pattern_dict[self.pattern_dict[self.output_names[0]].input_names[0]]

def match(self, node):
"""Match a given node to a pattern by comparing input names with the match point node from the pattern
dictionary.
"""
match_point = self.get_match_point()

def match_(node, pattern_node):
"""Match a given node to a pattern by comparing input names with the match point node from the pattern
dictionary.
"""
if pattern_node.op == "input":
return True

Expand Down Expand Up @@ -151,18 +175,22 @@ def match_(node, pattern_node):

@abstractmethod
def rewrite(self):
"""Abstract method to rewrite the graph based on matched patterns, to be implemented by subclasses."""
raise NotImplementedError("rewrite method must be implemented")

def parameter_check(self):
"""Check and validate parameters, returning True if valid."""
return True


class PatternGenerator:
def __init__(self, onnx_model):
"""Initialize the PatternGenerator class with an ONNX model and process its graph."""
self.graph = gs.import_onnx(onnx_model)
self.graph.fold_constants().cleanup().toposort()

def generate(self):
"""Generate the inputs, outputs, and nodes from the graph of the initialized ONNX model."""
inputs = self.graph.inputs
outputs = self.graph.outputs
nodes = self.graph.nodes
Expand All @@ -184,8 +212,8 @@ def generate(self):
template.append(
" ".join(
[node.op, name, str(len(feeds)), str(len(users))]
+ [get_name(feed.name) if not isinstance(feed, Constant) else "?" for feed in feeds]
+ [get_name(user.name) if not isinstance(user, Constant) else "?" for user in users]
+ ["?" if isinstance(feed, Constant) else get_name(feed.name) for feed in feeds]
+ ["?" if isinstance(user, Constant) else get_name(user.name) for user in users]
)
)

Expand Down
Loading

0 comments on commit 22a1d5c

Please sign in to comment.