Skip to content

Commit

Permalink
Merge branch 'main' into batch-matmul-no-hybrid
Browse files Browse the repository at this point in the history
  • Loading branch information
ddavis-2015 committed Aug 16, 2023
2 parents de92db3 + f4c2d08 commit da14c7d
Show file tree
Hide file tree
Showing 16 changed files with 288 additions and 18 deletions.
9 changes: 9 additions & 0 deletions .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
#
# TFLM Bazel configuration file.

# The semver-format version label embedded in build outputs when and where
# stamping is used. Note TFLM does not currently publish semver-versioned
# releases; however, this value is used where a version label is required, such
# as in the Python distribution package.
build --embed_label=0

# Get stamp values from a script's output
build --workspace_status_command=./tools/workspace_status.sh

# Use the following C++ standard
build --cxxopt -std=c++17

Expand Down
1 change: 1 addition & 0 deletions codegen/BUILD
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_library")
load("@tflm_pip_deps//:requirements.bzl", "requirement")

package(
Expand Down
21 changes: 20 additions & 1 deletion python/tflite_micro/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("//python:py_namespace.bzl", "py_namespace")
load("//tools:expand_stamp_vars.bzl", "expand_stamp_vars")
load("@rules_python//python:packaging.bzl", "py_package", "py_wheel")
load("@tflm_pip_deps//:requirements.bzl", "requirement")
load(
Expand Down Expand Up @@ -109,6 +110,14 @@ py_library(
],
)

# Generate a version attribute, imported as tflite_micro.__version__, using
# stamp (a.k.a. workspace status) variables.
expand_stamp_vars(
name = "version",
out = "_version.py",
template = "_version.py.in",
)

# Collect the `deps` and their transitive dependences together into a set of
# files to package. The files retain their full path relative to the workspace
# root, which determines the subpackage path at which they're located within
Expand All @@ -129,6 +138,7 @@ py_package(
deps = [
":postinstall_check",
":runtime",
":version",
],
)

Expand All @@ -151,16 +161,25 @@ py_namespace(
],
)

expand_stamp_vars(
name = "description_file",
out = "README.pypi.md",
template = "README.pypi.md.in",
)

py_wheel(
name = "whl",
description_file = ":description_file",
distribution = "tflite_micro",
requires = [
"flatbuffers",
"numpy",
"tensorflow",
],
stamp = 1, # 1 == always stamp
strip_path_prefixes = [package_name()],
version = "0.1.0",
summary = "TensorFlow Lite for Microcontrollers",
version = "{BUILD_EMBED_LABEL}.dev{STABLE_GIT_COMMIT_TIME}",
deps = [
":namespace",
],
Expand Down
5 changes: 5 additions & 0 deletions python/tflite_micro/README.pypi.md.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# TensorFlow Lite for Microcontrollers

This package is built from commit
[{STABLE_GIT_HASH}](https://github.com/tensorflow/tflite-micro/blob/{STABLE_GIT_HASH}/python/tflite_micro)
of [github.com/tensorflow/tflite-micro](https://github.com/tensorflow/tflite-micro).
4 changes: 4 additions & 0 deletions python/tflite_micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----

# Define a public API for the package by providing aliases for modules which
# are otherwise deeply nested in subpackages determined by their location in
Expand All @@ -20,5 +21,8 @@

from tflite_micro.python.tflite_micro import runtime

# Unambiguously identify the source used to build the package.
from tflite_micro.python.tflite_micro._version import __version__

# Ordered after `runtime` to avoid a circular dependency
from tflite_micro.python.tflite_micro import postinstall_check
1 change: 1 addition & 0 deletions python/tflite_micro/_version.py.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "{BUILD_EMBED_LABEL}.dev{STABLE_GIT_COMMIT_TIME}-g{STABLE_GIT_HASH}"
3 changes: 1 addition & 2 deletions python/tflite_micro/signal/utils/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Signal python utilities.
load("@rules_python//python:defs.bzl", "py_library")
load("@rules_python//python:defs.bzl", "py_test")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@tflm_pip_deps//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

Expand Down
6 changes: 3 additions & 3 deletions python/tflite_micro/whl_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pip show --files tflite-micro

# Run the package's post-installation checks.
python3 << HEREDOC
import sys
from tflite_micro import postinstall_check
sys.exit(0 if postinstall_check.passed() else 1)
import sys, tflite_micro
print(tflite_micro.__version__)
sys.exit(0 if tflite_micro.postinstall_check.passed() else 1)
HEREDOC
1 change: 0 additions & 1 deletion tensorflow/lite/micro/examples/person_detection/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Description:
# TensorFlow Lite for Microcontrollers Vision Example.
load("@rules_python//python:defs.bzl", "py_binary", "py_library", "py_test")
load("//tensorflow/lite/micro:build_def.bzl", "generate_cc_arrays")

package(
Expand Down
14 changes: 8 additions & 6 deletions tensorflow/lite/micro/kernels/testdata/lstm_test_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,14 @@ NodeQuantizationParameters Get2X2Int16LstmQuantizationSettings() {

// state quantization parameters
quantization_settings.input = {/*scale=*/3.0518044e-5, /*zp=*/0,
/*symmetry=*/false};
quantization_settings.output = {/*scale=*/1.8310826e-5, /*zp=*/-5461,
/*symmetry=*/false};
quantization_settings.hidden_state = {/*scale=*/1.8310826e-5, /*zp=*/-5461,
/*symmetry=*/false};
quantization_settings.cell_state = {/*scale=*/0.00024414062, /*zp=*/0,
/*symmetry=*/true};
quantization_settings.output = {/*scale=*/2.1362956633198035e-05, /*zp=*/0,
/*symmetry=*/true};
quantization_settings.hidden_state = {/*scale=*/2.1362956633198035e-05,
/*zp=*/0,
/*symmetry=*/true};
quantization_settings.cell_state = {/*scale=*/0.00024414807580797754,
/*zp=*/0,
/*symmetry=*/true};

// gate quantization parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
2. Print the intermediate step outputs inside the LSTM for a single step LSTM invocation (Get2X2GateOutputCheckData in .cc)
3. Print the outputs for multi-step LSTM invocation (Get2X2LstmEvalCheckData in .cc)
Every invocation gives three types information:
1. Quantized output: kernel output in integer
Every invocation gives three types information:
1. Quantized output: kernel output in integer
2. Dequantized output: Quantized output in floating point representation
3. Float output: output from the floating point computation (i.e., float kernel)
Note:
Note:
1. Change quantization settings in _KERNEL_CONFIG to see the outcomes from various quantization schema (e.g., 8x8 Vs. 16x8)
2. Only single batch inference is supporte here. Change _GATE_TEST_DATA or _MULTISTEP_TEST_DATA to see kernel outputs on different input data
3. The quantization computation here is not the exact as the c++ implementation. The integer calculation is mimiced here using floating point.
3. The quantization computation here is not the exact as the c++ implementation. The integer calculation is emulated here using floating point.
No fixed point math is implemented here. The purpose is to illustrate the computation procedure and possible quantization error accumulation, not for bit exactness.
"""
from absl import app
Expand Down Expand Up @@ -88,7 +88,7 @@
_MULTISTEP_TEST_DATA = {
'init_hidden_state_vals': [0, 0],
'init_cell_state_vals': [0, 0],
'input_data': [0.2, 0.3, 0.2, 0.3, 0.2, 0.3], # three time steps
'input_data': [0.2, 0.3, 0.2, 0.3, 0.2, 0.3], # three time steps
'hidden_state_range': (-0.5, 0.7),
'cell_state_range': [-8, 8],
'input_data_range': [-1, 1]
Expand Down
16 changes: 16 additions & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_test")

package(
default_visibility = ["//visibility:public"],
)

py_binary(
name = "expand_stamp_vars",
srcs = ["expand_stamp_vars.py"],
)

py_test(
name = "expand_stamp_vars_test",
srcs = ["expand_stamp_vars_test.py"],
deps = [":expand_stamp_vars"],
)
52 changes: 52 additions & 0 deletions tools/expand_stamp_vars.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----

def expand_stamp_vars(name, template, out):
"""Macro for expanding a template using workspace status variables.
Typical usage in a BUILD file:
expand_stamp_vars(
name = "version",
template = "_version.py.in",
out = "_version.py",
)
Writes `template` to `out`, expanding references of the form '{KEY}' to the
value of the corresponding Bazel workspace status variable.
"""

# This macro uses a genrule to call a helper program at Bazel execution
# time, because workspace variables are not available until execution time.
# Workspace variables are generated by bazel on each invocation, and
# written to the hardcoded files names used below. See the Bazel
# documentation for the option --workspace_status_command.

native.genrule(
name = name,
srcs = [template],
outs = [out],
cmd = "$(location //tools:expand_stamp_vars) " +
"bazel-out/stable-status.txt " +
"bazel-out/volatile-status.txt " +
"<$< >$@",
tools = [
"//tools:expand_stamp_vars",
],

# Undocumented, but valid, and the only way to declare the necessary
# dependencies on {stable,volatile}-status.txt.
stamp = 1,
)
75 changes: 75 additions & 0 deletions tools/expand_stamp_vars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----

# A filter that expands Bazel workspace stamp variables.
#
# For example, the input steam:
#
# This build was compiled at {BUILD_DATE}.
#
# is expanded into the output stream:
#
# This build was compiled at 2023-02-10T14:15.
#
# Stamp variable key-value pairs are read from all files passed as positional
# arguments. These files are typically bazel-out/stable-status.txt and
# bazel-out/volatile-status.txt. See the Bazel documentation for the option
# --workspace_status_command.

import sys


def read_stamps(file):
"""Return a dictionary of key-value pairs read from a stamp file.
These files are typically bazel-out/stable-status.txt and
bazel-out/volatile-status.txt. See the Bazel documentation for the option
--workspace_status_command."""

stamps = {}
for line in file:
try:
key, value = line.split(" ", maxsplit=1)
stamps[key] = value.strip()
except ValueError:
pass # Skip blank lines, etc.

return stamps


def expand(istream, ostream, stamps):
"""Write istream to ostream, expanding placeholders like {KEY}."""
for line in istream:
for key, value in stamps.items():
line = line.replace(f"{{{key}}}", value)
ostream.write(line)


def _main():
"""Stamp variables are read from all files passed as positional arguments."""
stamps = {}
for name in sys.argv[1:]:
with open(name) as f:
stamps.update(read_stamps(f))

expand(sys.stdin, sys.stdout, stamps)

sys.exit(0)


if __name__ == "__main__":
_main()
46 changes: 46 additions & 0 deletions tools/expand_stamp_vars_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python3

# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----

# A test for the filter that expands Bazel workspace stamp variables.

from tools import expand_stamp_vars

import io
import unittest


class FilterTest(unittest.TestCase):
"""A simple test of the expansion feature."""

def test_basic(self):
stamps = """
BUILD_STAMP_ONE value_one
BUILD_STAMP_TWO value_two
"""
input = "This is {BUILD_STAMP_TWO}. This is {BUILD_STAMP_ONE}."
golden = "This is value_two. This is value_one."

istream = io.StringIO(input)
ostream = io.StringIO()
stamps = expand_stamp_vars.read_stamps(io.StringIO(stamps))
expand_stamp_vars.expand(istream, ostream, stamps)

self.assertEqual(ostream.getvalue(), golden)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit da14c7d

Please sign in to comment.