diff --git a/_sources/development.md.txt b/_sources/development.md.txt
new file mode 100644
index 0000000..23ae97a
--- /dev/null
+++ b/_sources/development.md.txt
@@ -0,0 +1,39 @@
+# Development
+
+For users who wish to develop using this codebase, the following setup is required:
+
+**First-time setup**:
+
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+pip install -r requirements-dev.txt # Or requirements-dev-ipu.txt for the ipu
+```
+
+**Subsequent setup**:
+
+```bash
+source .venv/bin/activate
+```
+
+**Run pre-flight checks** (or run `./dev --help` to see supported commands):
+
+```bash
+./dev
+```
+
+**IDE recommendations**:
+
+- Python intepreter is set to `.venv/bin/python`
+- Format-on-save enabled
+- Consider a `.env` file for setting `PYTHONPATH`, for example `echo "PYTHONPATH=$(pwd)" > .env`
+ (note that this will be a different path if using devcontainers)
+
+**Docs development**:
+
+```bash
+cd docs/
+make html
+```
+
+then view `docs/_build/html/index.html` in your browser.
\ No newline at end of file
diff --git a/_sources/development.rst.txt b/_sources/development.rst.txt
deleted file mode 100644
index 339dc8f..0000000
--- a/_sources/development.rst.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-.. include:: ../README.md
- :parser: myst_parser.sphinx_
diff --git a/development.html b/development.html
index 4711169..c81d2a4 100644
--- a/development.html
+++ b/development.html
@@ -4,7 +4,7 @@
-
2. Unit-Scaled Maximal Update Parameterization (u-μP) — unit-scaling documentation
+ 2. Development — unit-scaling documentation
@@ -47,13 +47,7 @@
Contents
1. User guide
-2. Developer guide
-
+2. Developer guide
3. Limitations
4. Blog
5. API reference
@@ -73,9 +67,9 @@
@@ -83,32 +77,8 @@
-
-2. Unit-Scaled Maximal Update Parameterization (u-μP)
-A library for unit scaling in PyTorch, based on the paper Unit-Scaled Maximal Update Parametrization and previous work Unit Scaling: Out-of-the-Box Low-Precision Training .
-Documentation can be found at
-https://graphcore-research.github.io/unit-scaling and an example notebook at examples/demo.ipynb .
-Note: The library is currently in its beta release.
-Some features have yet to be implemented and occasional bugs may be present.
-We’re keen to help users with any problems they encounter.
-
-2.1. Installation
-To install the unit-scaling
library, run:
-pip install git + https : // github . com / graphcore - research / unit - scaling . git
-
-
-
-
-
-2.3. Development
+
+2. Development
For users who wish to develop using this codebase, the following setup is required:
First-time setup :
python3 -m venv .venv
@@ -137,12 +107,6 @@ 2.3. Development
then view docs/_build/html/index.html
in your browser.
-
-
-2.4. License
-Copyright (c) 2023 Graphcore Ltd. Licensed under the Apache 2.0 License.
-See NOTICE.md for further details.
-
diff --git a/index.html b/index.html
index b6d143d..05b3799 100644
--- a/index.html
+++ b/index.html
@@ -122,13 +122,7 @@ 2. Developer guide
-
+2. Developer guide
3. Limitations
4. Blog
4.1. Almost scaled dot-product self attention
diff --git a/limitations.html b/limitations.html
index e5b8b73..fa073b6 100644
--- a/limitations.html
+++ b/limitations.html
@@ -19,7 +19,7 @@
-
+
@@ -96,7 +96,7 @@ 3. Limitations
diff --git a/objects.inv b/objects.inv
index aa34de3..f9ee50f 100644
Binary files a/objects.inv and b/objects.inv differ
diff --git a/searchindex.js b/searchindex.js
index 6af6e30..1c044a5 100644
--- a/searchindex.js
+++ b/searchindex.js
@@ -1 +1 @@
-Search.setIndex({"docnames": ["api_reference", "blog", "development", "generated/unit_scaling", "generated/unit_scaling.CrossEntropyLoss", "generated/unit_scaling.DepthModuleList", "generated/unit_scaling.DepthSequential", "generated/unit_scaling.Dropout", "generated/unit_scaling.Embedding", "generated/unit_scaling.GELU", "generated/unit_scaling.LayerNorm", "generated/unit_scaling.Linear", "generated/unit_scaling.LinearReadout", "generated/unit_scaling.MHSA", "generated/unit_scaling.MLP", "generated/unit_scaling.Parameter", "generated/unit_scaling.RMSNorm", "generated/unit_scaling.SiLU", "generated/unit_scaling.Softmax", "generated/unit_scaling.TransformerDecoder", "generated/unit_scaling.TransformerLayer", "generated/unit_scaling.analysis", "generated/unit_scaling.analysis.example_batch", "generated/unit_scaling.analysis.graph_to_dataframe", "generated/unit_scaling.analysis.plot", "generated/unit_scaling.analysis.visualiser", "generated/unit_scaling.constraints", "generated/unit_scaling.constraints.amean", "generated/unit_scaling.constraints.apply_constraint", "generated/unit_scaling.constraints.gmean", "generated/unit_scaling.constraints.hmean", "generated/unit_scaling.constraints.to_grad_input_scale", "generated/unit_scaling.constraints.to_left_grad_scale", "generated/unit_scaling.constraints.to_output_scale", "generated/unit_scaling.constraints.to_right_grad_scale", "generated/unit_scaling.core", "generated/unit_scaling.core.functional", "generated/unit_scaling.core.functional.logarithmic_interpolation", "generated/unit_scaling.core.functional.rms", "generated/unit_scaling.core.functional.scale_elementwise", "generated/unit_scaling.core.functional.transformer_residual_scaling_rule", "generated/unit_scaling.formats", "generated/unit_scaling.formats.FPFormat", "generated/unit_scaling.formats.format_to_tuple", "generated/unit_scaling.formats.tuple_to_format", "generated/unit_scaling.functional", "generated/unit_scaling.functional.add", "generated/unit_scaling.functional.cross_entropy", "generated/unit_scaling.functional.dropout", "generated/unit_scaling.functional.embedding", "generated/unit_scaling.functional.gelu", "generated/unit_scaling.functional.layer_norm", "generated/unit_scaling.functional.linear", "generated/unit_scaling.functional.linear_readout", "generated/unit_scaling.functional.matmul", "generated/unit_scaling.functional.mse_loss", "generated/unit_scaling.functional.residual_add", "generated/unit_scaling.functional.residual_apply", "generated/unit_scaling.functional.residual_split", "generated/unit_scaling.functional.rms_norm", "generated/unit_scaling.functional.scaled_dot_product_attention", "generated/unit_scaling.functional.silu", "generated/unit_scaling.functional.silu_glu", "generated/unit_scaling.functional.softmax", "generated/unit_scaling.optim", "generated/unit_scaling.optim.Adam", "generated/unit_scaling.optim.AdamW", "generated/unit_scaling.optim.SGD", "generated/unit_scaling.optim.lr_scale_for_depth", "generated/unit_scaling.optim.lr_scale_func_adam", "generated/unit_scaling.optim.lr_scale_func_sgd", "generated/unit_scaling.optim.scaled_parameters", "generated/unit_scaling.parameter", "generated/unit_scaling.parameter.OrderedDict", "generated/unit_scaling.parameter.Parameter", "generated/unit_scaling.parameter.ParameterData", "generated/unit_scaling.parameter.Protocol", "generated/unit_scaling.parameter.Tensor", "generated/unit_scaling.parameter.has_parameter_data", "generated/unit_scaling.scale", "generated/unit_scaling.scale.scale_bwd", "generated/unit_scaling.scale.scale_fwd", "generated/unit_scaling.transformer_residual_scaling_rule", "generated/unit_scaling.transforms", "generated/unit_scaling.transforms.Metrics", "generated/unit_scaling.transforms.compile", "generated/unit_scaling.transforms.prune_non_float_tensors", "generated/unit_scaling.transforms.prune_same_scale_tensors", "generated/unit_scaling.transforms.prune_selected_nodes", "generated/unit_scaling.transforms.simulate_format", "generated/unit_scaling.transforms.simulate_fp8", "generated/unit_scaling.transforms.track_scales", "generated/unit_scaling.transforms.unit_scale", "generated/unit_scaling.transforms.utils", "generated/unit_scaling.transforms.utils.apply_transform", "generated/unit_scaling.transforms.utils.patch_to_expand_modules", "generated/unit_scaling.transforms.utils.replace_node_with_function", "generated/unit_scaling.transforms.utils.torch_nn_modules_to_user_modules", "generated/unit_scaling.utils", "generated/unit_scaling.utils.ScalePair", "generated/unit_scaling.utils.ScaleTracker", "generated/unit_scaling.utils.ScaleTrackingInterpreter", "generated/unit_scaling.utils.analyse_module", "generated/unit_scaling.visualiser", "index", "limitations", "posts/almost_scaled_dot_product_attention", "user_guide"], "filenames": ["api_reference.rst", "blog.rst", "development.rst", "generated/unit_scaling.rst", "generated/unit_scaling.CrossEntropyLoss.rst", "generated/unit_scaling.DepthModuleList.rst", "generated/unit_scaling.DepthSequential.rst", "generated/unit_scaling.Dropout.rst", "generated/unit_scaling.Embedding.rst", "generated/unit_scaling.GELU.rst", "generated/unit_scaling.LayerNorm.rst", "generated/unit_scaling.Linear.rst", "generated/unit_scaling.LinearReadout.rst", "generated/unit_scaling.MHSA.rst", "generated/unit_scaling.MLP.rst", "generated/unit_scaling.Parameter.rst", "generated/unit_scaling.RMSNorm.rst", "generated/unit_scaling.SiLU.rst", "generated/unit_scaling.Softmax.rst", "generated/unit_scaling.TransformerDecoder.rst", "generated/unit_scaling.TransformerLayer.rst", "generated/unit_scaling.analysis.rst", "generated/unit_scaling.analysis.example_batch.rst", "generated/unit_scaling.analysis.graph_to_dataframe.rst", "generated/unit_scaling.analysis.plot.rst", "generated/unit_scaling.analysis.visualiser.rst", "generated/unit_scaling.constraints.rst", "generated/unit_scaling.constraints.amean.rst", "generated/unit_scaling.constraints.apply_constraint.rst", "generated/unit_scaling.constraints.gmean.rst", "generated/unit_scaling.constraints.hmean.rst", "generated/unit_scaling.constraints.to_grad_input_scale.rst", "generated/unit_scaling.constraints.to_left_grad_scale.rst", "generated/unit_scaling.constraints.to_output_scale.rst", "generated/unit_scaling.constraints.to_right_grad_scale.rst", "generated/unit_scaling.core.rst", "generated/unit_scaling.core.functional.rst", "generated/unit_scaling.core.functional.logarithmic_interpolation.rst", "generated/unit_scaling.core.functional.rms.rst", "generated/unit_scaling.core.functional.scale_elementwise.rst", "generated/unit_scaling.core.functional.transformer_residual_scaling_rule.rst", "generated/unit_scaling.formats.rst", "generated/unit_scaling.formats.FPFormat.rst", "generated/unit_scaling.formats.format_to_tuple.rst", "generated/unit_scaling.formats.tuple_to_format.rst", "generated/unit_scaling.functional.rst", "generated/unit_scaling.functional.add.rst", "generated/unit_scaling.functional.cross_entropy.rst", "generated/unit_scaling.functional.dropout.rst", "generated/unit_scaling.functional.embedding.rst", "generated/unit_scaling.functional.gelu.rst", "generated/unit_scaling.functional.layer_norm.rst", "generated/unit_scaling.functional.linear.rst", "generated/unit_scaling.functional.linear_readout.rst", "generated/unit_scaling.functional.matmul.rst", "generated/unit_scaling.functional.mse_loss.rst", "generated/unit_scaling.functional.residual_add.rst", "generated/unit_scaling.functional.residual_apply.rst", "generated/unit_scaling.functional.residual_split.rst", "generated/unit_scaling.functional.rms_norm.rst", "generated/unit_scaling.functional.scaled_dot_product_attention.rst", "generated/unit_scaling.functional.silu.rst", "generated/unit_scaling.functional.silu_glu.rst", "generated/unit_scaling.functional.softmax.rst", "generated/unit_scaling.optim.rst", "generated/unit_scaling.optim.Adam.rst", "generated/unit_scaling.optim.AdamW.rst", "generated/unit_scaling.optim.SGD.rst", "generated/unit_scaling.optim.lr_scale_for_depth.rst", "generated/unit_scaling.optim.lr_scale_func_adam.rst", "generated/unit_scaling.optim.lr_scale_func_sgd.rst", "generated/unit_scaling.optim.scaled_parameters.rst", "generated/unit_scaling.parameter.rst", "generated/unit_scaling.parameter.OrderedDict.rst", "generated/unit_scaling.parameter.Parameter.rst", "generated/unit_scaling.parameter.ParameterData.rst", "generated/unit_scaling.parameter.Protocol.rst", "generated/unit_scaling.parameter.Tensor.rst", "generated/unit_scaling.parameter.has_parameter_data.rst", "generated/unit_scaling.scale.rst", "generated/unit_scaling.scale.scale_bwd.rst", "generated/unit_scaling.scale.scale_fwd.rst", "generated/unit_scaling.transformer_residual_scaling_rule.rst", "generated/unit_scaling.transforms.rst", "generated/unit_scaling.transforms.Metrics.rst", "generated/unit_scaling.transforms.compile.rst", "generated/unit_scaling.transforms.prune_non_float_tensors.rst", "generated/unit_scaling.transforms.prune_same_scale_tensors.rst", "generated/unit_scaling.transforms.prune_selected_nodes.rst", "generated/unit_scaling.transforms.simulate_format.rst", "generated/unit_scaling.transforms.simulate_fp8.rst", "generated/unit_scaling.transforms.track_scales.rst", "generated/unit_scaling.transforms.unit_scale.rst", "generated/unit_scaling.transforms.utils.rst", "generated/unit_scaling.transforms.utils.apply_transform.rst", "generated/unit_scaling.transforms.utils.patch_to_expand_modules.rst", "generated/unit_scaling.transforms.utils.replace_node_with_function.rst", "generated/unit_scaling.transforms.utils.torch_nn_modules_to_user_modules.rst", "generated/unit_scaling.utils.rst", "generated/unit_scaling.utils.ScalePair.rst", "generated/unit_scaling.utils.ScaleTracker.rst", "generated/unit_scaling.utils.ScaleTrackingInterpreter.rst", "generated/unit_scaling.utils.analyse_module.rst", "generated/unit_scaling.visualiser.rst", "index.rst", "limitations.rst", "posts/almost_scaled_dot_product_attention.md", "user_guide.rst"], "titles": ["5. API reference", "4. Unit Scaling blog", "2. Unit-Scaled Maximal Update Parameterization (u-\u03bcP)", "5.1. unit_scaling", "5.1.4. unit_scaling.CrossEntropyLoss", "5.1.5. unit_scaling.DepthModuleList", "5.1.6. unit_scaling.DepthSequential", "5.1.7. unit_scaling.Dropout", "5.1.8. unit_scaling.Embedding", "5.1.9. unit_scaling.GELU", "5.1.10. unit_scaling.LayerNorm", "5.1.11. unit_scaling.Linear", "5.1.12. unit_scaling.LinearReadout", "5.1.13. unit_scaling.MHSA", "5.1.14. unit_scaling.MLP", "5.1.1. unit_scaling.Parameter", "5.1.15. unit_scaling.RMSNorm", "5.1.16. unit_scaling.SiLU", "5.1.17. unit_scaling.Softmax", "5.1.18. unit_scaling.TransformerDecoder", "5.1.19. unit_scaling.TransformerLayer", "5.2. unit_scaling.analysis", "5.2.1. unit_scaling.analysis.example_batch", "5.2.2. unit_scaling.analysis.graph_to_dataframe", "5.2.3. unit_scaling.analysis.plot", "5.2.4. unit_scaling.analysis.visualiser", "5.3. unit_scaling.constraints", "5.3.1. unit_scaling.constraints.amean", "5.3.2. unit_scaling.constraints.apply_constraint", "5.3.3. unit_scaling.constraints.gmean", "5.3.4. unit_scaling.constraints.hmean", "5.3.5. unit_scaling.constraints.to_grad_input_scale", "5.3.6. unit_scaling.constraints.to_left_grad_scale", "5.3.7. unit_scaling.constraints.to_output_scale", "5.3.8. unit_scaling.constraints.to_right_grad_scale", "5.1.20. unit_scaling.core", "5.1.20.1. unit_scaling.core.functional", "5.1.20.1.1. unit_scaling.core.functional.logarithmic_interpolation", "5.1.20.1.2. unit_scaling.core.functional.rms", "5.1.20.1.3. unit_scaling.core.functional.scale_elementwise", "5.1.20.1.4. unit_scaling.core.functional.transformer_residual_scaling_rule", "5.4. unit_scaling.formats", "5.4.3. unit_scaling.formats.FPFormat", "5.4.1. unit_scaling.formats.format_to_tuple", "5.4.2. unit_scaling.formats.tuple_to_format", "5.1.21. unit_scaling.functional", "5.1.21.1. unit_scaling.functional.add", "5.1.21.2. unit_scaling.functional.cross_entropy", "5.1.21.3. unit_scaling.functional.dropout", "5.1.21.4. unit_scaling.functional.embedding", "5.1.21.5. unit_scaling.functional.gelu", "5.1.21.6. unit_scaling.functional.layer_norm", "5.1.21.7. unit_scaling.functional.linear", "5.1.21.8. unit_scaling.functional.linear_readout", "5.1.21.9. unit_scaling.functional.matmul", "5.1.21.10. unit_scaling.functional.mse_loss", "5.1.21.11. unit_scaling.functional.residual_add", "5.1.21.12. unit_scaling.functional.residual_apply", "5.1.21.13. unit_scaling.functional.residual_split", "5.1.21.14. unit_scaling.functional.rms_norm", "5.1.21.15. unit_scaling.functional.scaled_dot_product_attention", "5.1.21.16. unit_scaling.functional.silu", "5.1.21.17. unit_scaling.functional.silu_glu", "5.1.21.18. unit_scaling.functional.softmax", "5.1.22. unit_scaling.optim", "5.1.22.5. unit_scaling.optim.Adam", "5.1.22.6. unit_scaling.optim.AdamW", "5.1.22.7. unit_scaling.optim.SGD", "5.1.22.1. unit_scaling.optim.lr_scale_for_depth", "5.1.22.2. unit_scaling.optim.lr_scale_func_adam", "5.1.22.3. unit_scaling.optim.lr_scale_func_sgd", "5.1.22.4. unit_scaling.optim.scaled_parameters", "5.1.23. unit_scaling.parameter", "5.1.23.3. unit_scaling.parameter.OrderedDict", "5.1.23.1. unit_scaling.parameter.Parameter", "5.1.23.4. unit_scaling.parameter.ParameterData", "5.1.23.5. unit_scaling.parameter.Protocol", "5.1.23.6. unit_scaling.parameter.Tensor", "5.1.23.2. unit_scaling.parameter.has_parameter_data", "5.5. unit_scaling.scale", "5.5.1. unit_scaling.scale.scale_bwd", "5.5.2. unit_scaling.scale.scale_fwd", "5.1.2. unit_scaling.transformer_residual_scaling_rule", "5.6. unit_scaling.transforms", "5.6.9. unit_scaling.transforms.Metrics", "5.6.1. unit_scaling.transforms.compile", "5.6.2. unit_scaling.transforms.prune_non_float_tensors", "5.6.3. unit_scaling.transforms.prune_same_scale_tensors", "5.6.4. unit_scaling.transforms.prune_selected_nodes", "5.6.5. unit_scaling.transforms.simulate_format", "5.6.6. unit_scaling.transforms.simulate_fp8", "5.6.7. unit_scaling.transforms.track_scales", "5.6.8. unit_scaling.transforms.unit_scale", "5.7. unit_scaling.transforms.utils", "5.7.1. unit_scaling.transforms.utils.apply_transform", "5.7.2. unit_scaling.transforms.utils.patch_to_expand_modules", "5.7.3. unit_scaling.transforms.utils.replace_node_with_function", "5.7.4. unit_scaling.transforms.utils.torch_nn_modules_to_user_modules", "5.8. unit_scaling.utils", "5.8.2. unit_scaling.utils.ScalePair", "5.8.3. unit_scaling.utils.ScaleTracker", "5.8.4. unit_scaling.utils.ScaleTrackingInterpreter", "5.8.1. unit_scaling.utils.analyse_module", "5.1.3. unit_scaling.visualiser", "Unit Scaling", "3. Limitations", "Almost-scaled dot-product attention", "1. User guide"], "terms": {"unit": [0, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 26, 29, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 71, 83, 85, 89, 90, 91, 92, 98, 102, 105, 106], "scale": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 39, 40, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 77, 82, 83, 85, 86, 87, 89, 90, 91, 92, 98, 102, 103, 105], "i": [0, 4, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17, 18, 19, 22, 23, 24, 25, 28, 33, 40, 47, 48, 49, 50, 52, 53, 54, 58, 60, 61, 62, 63, 65, 66, 67, 73, 74, 77, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 97, 100, 101, 103, 104, 105, 106], "implement": [0, 2, 7, 10, 13, 14, 16, 18, 19, 20, 25, 36, 60, 65, 66, 67, 77, 85, 92, 94, 102, 103, 104, 107], "us": [0, 2, 4, 6, 8, 9, 10, 11, 12, 13, 14, 19, 20, 22, 23, 25, 26, 28, 30, 42, 48, 49, 54, 56, 57, 58, 60, 63, 65, 66, 67, 76, 77, 83, 84, 85, 87, 89, 90, 91, 92, 94, 95, 97, 100, 101, 102, 103, 104, 105, 106, 107], "thin": 0, "wrapper": [0, 64, 92, 94], "around": [0, 77, 92, 94, 105], "exist": [0, 73, 77, 92, 107], "torch": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 17, 18, 23, 24, 25, 45, 46, 47, 49, 54, 60, 63, 65, 66, 67, 69, 70, 71, 72, 74, 75, 77, 83, 85, 89, 90, 91, 92, 94, 95, 97, 100, 102, 103, 105, 107], "nn": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 17, 18, 19, 23, 24, 25, 45, 49, 60, 65, 66, 67, 72, 74, 75, 77, 85, 89, 90, 92, 94, 95, 97, 100, 101, 102, 103, 107], "class": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 41, 42, 47, 49, 60, 63, 64, 65, 66, 67, 72, 73, 75, 76, 77, 83, 84, 92, 97, 98, 99, 100, 101, 102, 107], "function": [0, 3, 4, 6, 7, 9, 11, 12, 13, 14, 17, 18, 19, 20, 21, 23, 24, 26, 28, 35, 41, 64, 65, 66, 67, 72, 77, 79, 82, 83, 85, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 107], "document": [0, 2, 8, 18, 107], "also": [0, 4, 15, 17, 57, 61, 65, 66, 74, 75, 77, 91, 92, 100, 106, 107], "inherit": 0, "from": [0, 5, 7, 8, 10, 11, 12, 18, 22, 24, 25, 48, 49, 65, 66, 67, 73, 77, 85, 86, 87, 89, 90, 91, 92, 94, 100, 101, 102, 103, 107], "standard": [0, 10, 18, 25, 41, 77, 85, 90, 92, 94, 99, 100, 101, 102, 103, 107], "pytorch": [0, 2, 21, 60, 77, 85, 92, 107], "doc": [0, 2, 105], "modif": [0, 5, 6, 77, 100], "note": [0, 2, 4, 5, 6, 8, 10, 11, 12, 16, 19, 47, 49, 54, 60, 65, 66, 67, 77, 85, 91, 92, 94, 100, 107], "some": [0, 2, 4, 47, 52, 53, 54, 60, 65, 66, 67, 77, 92, 104, 106, 107], "mai": [0, 2, 4, 25, 52, 53, 54, 58, 60, 65, 66, 67, 77, 85, 87, 91, 92, 100, 103, 104, 105, 107], "longer": [0, 77], "relev": [0, 10, 24, 25, 71, 103], "ar": [0, 1, 4, 5, 6, 7, 10, 11, 12, 16, 18, 24, 25, 40, 47, 49, 54, 58, 60, 65, 66, 67, 73, 76, 77, 82, 84, 87, 89, 90, 91, 92, 94, 100, 101, 103, 106, 107], "nevertheless": 0, "The": [0, 2, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 23, 24, 28, 39, 40, 42, 46, 47, 49, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 74, 77, 82, 85, 86, 87, 89, 90, 91, 92, 94, 100, 101, 104, 106, 107], "built": [0, 102], "mirror": [0, 77], "close": 0, "possibl": [0, 49, 77, 107], "can": [0, 2, 4, 5, 6, 7, 8, 24, 58, 60, 65, 66, 67, 76, 77, 85, 87, 91, 92, 100, 101, 104, 105, 106, 107], "easili": 0, "swap": [0, 77, 107], "out": [0, 2, 7, 11, 12, 46, 52, 53, 54, 56, 64, 77, 100, 104, 106, 107], "equival": [0, 4, 8, 10, 60, 77, 89, 90, 92, 100, 102, 107], "For": [0, 2, 8, 10, 25, 37, 40, 54, 60, 65, 66, 67, 71, 77, 82, 91, 100, 103, 104, 106, 107], "code": [0, 6, 25, 77, 102, 103, 107], "which": [0, 4, 6, 10, 14, 18, 25, 58, 60, 63, 65, 66, 67, 71, 77, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 100, 101, 103, 106, 107], "follow": [0, 2, 4, 25, 39, 54, 60, 65, 66, 67, 73, 77, 85, 91, 95, 100, 103, 104, 106, 107], "import": [0, 24, 67, 85, 86, 87, 91, 92, 106, 107], "f": [0, 39, 47, 49, 60, 65, 66, 67, 73, 77, 107], "appli": [0, 4, 6, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 28, 39, 46, 47, 48, 50, 51, 52, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 64, 71, 77, 80, 81, 85, 86, 87, 92, 94, 95, 100, 106, 107], "first": [0, 2, 6, 8, 39, 54, 65, 66, 77, 85, 94, 100, 101, 107], "ad": [0, 6, 10, 16, 60, 65, 66, 67, 77, 91, 105, 106, 107], "unit_sc": [0, 104, 105, 107], "uu": [0, 15, 71, 74, 107], "u": [0, 11, 12, 15, 40, 64, 65, 66, 67, 71, 72, 74, 75, 82, 92, 102, 107], "replac": [0, 77, 92, 96, 107], "letter": 0, "those": [0, 4, 47, 77, 89, 90, 104, 107], "assum": [0, 31, 32, 33, 34, 106, 107], "thei": [0, 1, 2, 6, 24, 25, 65, 66, 67, 76, 77, 92, 100, 103, 104, 106, 107], "support": [0, 2, 4, 7, 8, 11, 12, 15, 25, 46, 47, 48, 49, 52, 53, 54, 60, 65, 66, 67, 74, 75, 77, 78, 85, 92, 94, 100, 103], "click": 0, "below": [0, 10, 18, 47, 107], "full": [0, 22, 65, 66, 67, 77, 91, 100, 106, 107], "transform": [1, 6, 10, 11, 12, 19, 20, 23, 24, 25, 39, 40, 52, 53, 64, 65, 66, 67, 77, 82, 103, 104, 105, 106, 107], "seem": [1, 106], "all": [1, 5, 8, 11, 12, 23, 24, 25, 40, 49, 60, 63, 65, 66, 67, 73, 77, 82, 86, 87, 88, 92, 100, 103, 106, 107], "you": [1, 4, 52, 53, 54, 60, 65, 66, 67, 77, 100, 101, 105, 106, 107], "need": [1, 4, 58, 77, 85, 91, 100, 102, 106, 107], "we": [1, 2, 25, 65, 66, 67, 77, 87, 90, 92, 94, 100, 101, 103, 104, 105, 106, 107], "don": [1, 24, 25, 65, 66, 67, 86, 87, 103, 106, 107], "t": [1, 8, 24, 25, 44, 49, 52, 53, 65, 66, 67, 76, 77, 85, 86, 87, 92, 94, 95, 97, 100, 103, 105, 106, 107], "fulli": [1, 101, 106], "understand": [1, 106, 107], "why": [1, 60, 106, 107], "work": [1, 2, 25, 60, 77, 85, 93, 102, 103, 104, 105, 107], "so": [1, 18, 60, 63, 65, 66, 67, 77, 85, 91, 92, 94, 100, 101, 106, 107], "well": [1, 4, 65, 66, 67, 77, 105, 106, 107], "while": [1, 4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 100, 106], "notic": [1, 2, 52, 53, 54, 106], "someth": [1, 65, 66, 67, 106, 107], "surpris": [1, 106], "about": [1, 77, 84, 106], "heart": [1, 106], "architectur": [1, 4, 47, 106], "how": [1, 2, 77, 100, 104, 106], "output": [1, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 24, 25, 31, 32, 33, 34, 39, 46, 47, 49, 50, 52, 53, 54, 60, 61, 62, 63, 74, 77, 86, 100, 101, 102, 103, 106, 107], "dougla": [1, 106], "orr": [1, 106], "octob": [1, 106], "2023": [1, 2, 104, 106, 107], "A": [2, 4, 5, 6, 7, 8, 13, 14, 18, 19, 20, 47, 49, 54, 60, 63, 65, 66, 67, 77, 84, 85, 104, 106, 107], "librari": [2, 35, 91, 92, 104, 105, 107], "base": [2, 25, 52, 53, 60, 67, 71, 76, 100, 103, 107], "paper": [2, 7, 10, 16, 40, 65, 66, 82, 92, 104, 106, 107], "parametr": 2, "previou": [2, 77, 87, 92], "box": [2, 64, 101, 104, 107], "low": [2, 4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 104, 107], "precis": [2, 11, 12, 54, 60, 104, 107], "train": [2, 4, 7, 8, 10, 48, 49, 58, 60, 65, 66, 67, 104, 105, 106, 107], "found": [2, 73, 104, 106, 107], "http": [2, 20, 77, 104, 107], "graphcor": [2, 104, 106, 107], "research": [2, 104, 107], "github": [2, 77, 104, 107], "io": 2, "an": [2, 4, 5, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 22, 23, 24, 25, 33, 39, 40, 48, 49, 54, 58, 60, 65, 66, 67, 71, 73, 74, 77, 82, 86, 87, 88, 89, 91, 92, 99, 100, 101, 103, 106, 107], "exampl": [2, 4, 6, 7, 8, 9, 10, 11, 12, 17, 18, 19, 37, 46, 47, 49, 54, 60, 65, 66, 67, 71, 76, 77, 100, 101, 102, 105, 106, 107], "notebook": [2, 104, 106, 107], "demo": 2, "ipynb": 2, "current": [2, 19, 25, 60, 65, 66, 67, 77, 85, 92, 94, 101, 103, 104, 105, 107], "its": [2, 39, 54, 65, 66, 77, 81, 84, 85, 96, 100, 104, 107], "beta": [2, 10, 52, 53, 54, 60, 65, 66, 77, 104, 107], "releas": [2, 77, 85, 104, 107], "featur": [2, 7, 52, 53, 54, 104, 105, 107], "have": [2, 4, 8, 23, 24, 49, 52, 53, 54, 58, 65, 66, 67, 77, 86, 87, 89, 90, 91, 92, 100, 104, 105, 106, 107], "yet": [2, 65, 66, 77, 92, 104, 106, 107], "occasion": [2, 104, 107], "bug": [2, 104, 105, 107], "present": [2, 73, 77, 96, 104, 107], "re": [2, 25, 63, 77, 103, 104, 105, 106, 107], "keen": [2, 104, 105, 107], "help": [2, 104, 105, 106, 107], "user": [2, 25, 60, 65, 66, 67, 77, 89, 90, 91, 92, 94, 97, 100, 103, 104], "ani": [2, 6, 7, 8, 9, 10, 11, 12, 17, 18, 24, 25, 33, 52, 53, 60, 64, 65, 66, 67, 71, 77, 88, 91, 92, 94, 95, 96, 100, 101, 102, 103, 104, 105, 107], "problem": [2, 4, 104, 107], "encount": [2, 104, 107], "To": [2, 60, 77, 97, 101, 104, 107], "run": [2, 6, 23, 24, 60, 65, 66, 67, 77, 89, 90, 101, 104, 107], "pip": [2, 104, 107], "git": [2, 104, 107], "com": [2, 77, 104, 107], "demonstr": [2, 19, 104], "overview": [2, 104, 107], "see": [2, 4, 8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 60, 61, 63, 76, 77, 92, 100, 101, 104, 105, 106, 107], "fp8": [2, 90, 104, 107], "show": [2, 24, 104, 107], "nanogpt": [2, 104], "model": [2, 6, 21, 24, 25, 58, 65, 66, 67, 71, 86, 87, 89, 90, 91, 92, 98, 103, 104, 106], "more": [2, 8, 49, 60, 61, 63, 65, 66, 67, 77, 91, 100, 105, 106, 107], "depth": [2, 5, 6, 40, 68, 82, 107], "explan": [2, 106], "consult": 2, "our": [2, 65, 66, 92, 94, 100, 104, 105, 106, 107], "And": 2, "practic": [2, 24, 77, 91, 95, 107], "introduct": [2, 107], "guid": [2, 92, 104], "who": [2, 104, 105, 107], "wish": [2, 94, 104, 107], "thi": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 25, 28, 39, 40, 42, 46, 48, 49, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 71, 73, 75, 77, 82, 85, 86, 87, 89, 90, 91, 92, 94, 95, 97, 100, 101, 103, 104, 105, 106, 107], "codebas": [2, 104, 107], "setup": 2, "requir": [2, 4, 60, 77, 91, 92, 100, 107], "time": [2, 10, 46, 54, 65, 66, 67, 77, 107], "python3": [2, 102], "m": [2, 7, 9, 11, 12, 17, 18, 54, 85, 89, 90, 91, 92, 94], "venv": 2, "sourc": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 77, 78, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 97, 99, 100, 101, 102, 103], "bin": [2, 77], "activ": [2, 10, 17, 61, 89, 106, 107], "r": [2, 8, 49, 77, 100], "dev": 2, "txt": 2, "Or": 2, "ipu": [2, 77], "subsequ": [2, 6], "pre": [2, 65, 66, 67, 101, 107], "flight": 2, "check": [2, 71, 76, 77, 78, 100, 106, 107], "command": 2, "id": [2, 22, 65, 66, 67], "recommend": [2, 25, 29, 77, 103, 104, 107], "python": [2, 5, 77, 102], "intepret": 2, "set": [2, 4, 7, 10, 11, 12, 16, 40, 47, 48, 60, 65, 66, 67, 71, 73, 77, 82, 84, 88, 91, 92, 100, 107], "format": [2, 77, 89, 90, 104, 107], "save": [2, 65, 66, 67, 77, 100], "enabl": [2, 60, 67, 71, 77, 79, 92, 94, 100, 107], "consid": [2, 4, 77, 106, 107], "env": [2, 100], "file": 2, "pythonpath": 2, "echo": 2, "pwd": 2, "differ": [2, 6, 11, 12, 25, 49, 54, 58, 60, 65, 66, 67, 77, 79, 85, 94, 103, 106, 107], "path": [2, 22, 25, 103], "devcontain": 2, "cd": 2, "make": [2, 24, 60, 77, 92, 94, 97, 106, 107], "html": 2, "view": [2, 24, 73, 77], "_build": 2, "index": [2, 4, 5, 8, 40, 49, 77, 82, 100], "your": [2, 77, 100, 107], "browser": 2, "copyright": 2, "c": [2, 4, 10, 46, 47, 60, 76, 77, 100], "ltd": 2, "under": [2, 49, 106], "apach": 2, "2": [2, 7, 8, 9, 10, 16, 17, 18, 38, 46, 49, 50, 52, 53, 54, 60, 65, 66, 67, 77, 87, 100, 102, 107], "0": [2, 4, 7, 8, 9, 10, 13, 17, 18, 19, 20, 37, 38, 40, 42, 46, 47, 48, 49, 50, 52, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 76, 77, 82, 100, 102, 106, 107], "md": 2, "further": [2, 65, 66, 77, 106, 107], "detail": [2, 8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 65, 66, 76, 77, 92, 100, 101, 105, 107], "version": [3, 25, 39, 45, 54, 65, 66, 67, 77, 85, 91, 92, 94, 95, 97, 103, 107], "common": [3, 26, 45, 46, 65, 66, 67, 105, 107], "modul": [3, 5, 6, 7, 8, 10, 11, 12, 16, 19, 23, 24, 25, 28, 35, 49, 54, 60, 77, 83, 85, 86, 87, 89, 90, 91, 92, 94, 95, 97, 101, 102, 103, 106, 107], "mult": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "float": [4, 7, 8, 9, 10, 13, 16, 17, 18, 19, 20, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 42, 46, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 77, 80, 81, 82, 84, 86, 87, 91, 99, 101, 103, 107], "1": [4, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 37, 40, 46, 47, 49, 50, 54, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 77, 82, 87, 100, 102, 104, 107], "weight": [4, 8, 10, 11, 12, 15, 16, 17, 19, 20, 37, 47, 49, 51, 52, 53, 56, 57, 58, 59, 60, 61, 65, 66, 67, 71, 74, 77, 89, 92, 102, 106, 107], "tensor": [4, 8, 11, 12, 15, 16, 18, 22, 24, 25, 38, 39, 42, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 71, 74, 80, 81, 84, 86, 87, 91, 94, 99, 100, 101, 102, 103, 106, 107], "none": [4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 24, 28, 38, 39, 46, 47, 49, 50, 51, 52, 53, 54, 55, 59, 60, 61, 63, 65, 66, 67, 71, 73, 74, 77, 96, 97, 99, 100, 101, 102, 107], "size_averag": [4, 47, 55], "bool": [4, 7, 8, 10, 11, 12, 13, 16, 17, 20, 24, 25, 38, 47, 48, 49, 55, 60, 61, 65, 66, 67, 71, 77, 96, 100, 101, 102, 103], "ignore_index": [4, 47], "int": [4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 25, 38, 40, 42, 43, 44, 46, 47, 49, 51, 59, 63, 65, 66, 67, 74, 76, 77, 82, 84, 100, 101, 103, 107], "100": [4, 37, 47, 77], "reduc": [4, 47, 55, 77, 89, 90], "reduct": [4, 47, 55, 77], "str": [4, 9, 11, 12, 13, 14, 17, 18, 19, 20, 22, 24, 25, 28, 39, 42, 46, 47, 50, 52, 53, 54, 55, 61, 63, 65, 66, 67, 71, 77, 88, 101, 102, 103], "mean": [4, 7, 9, 10, 11, 12, 16, 17, 18, 24, 27, 29, 30, 38, 47, 52, 53, 55, 77, 85, 107], "label_smooth": [4, 47], "comput": [4, 7, 8, 10, 18, 27, 29, 30, 38, 40, 47, 49, 55, 60, 63, 65, 66, 77, 82, 100, 107], "cross": [4, 47, 77, 107], "entropi": [4, 47, 107], "loss": [4, 24, 25, 47, 65, 66, 67, 86, 87, 91, 103, 107], "between": [4, 6, 24, 37, 40, 47, 58, 65, 66, 67, 77, 82, 107], "input": [4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 23, 24, 25, 31, 32, 33, 34, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 77, 80, 81, 87, 90, 91, 92, 100, 101, 102, 103, 106, 107], "logit": [4, 47, 77], "target": [4, 47, 55, 67, 77, 88, 96, 101, 107], "It": [4, 6, 25, 63, 65, 66, 67, 77, 100, 103, 106, 107], "when": [4, 6, 9, 10, 11, 12, 16, 18, 47, 50, 54, 60, 65, 66, 67, 77, 89, 90, 91, 95, 97, 101, 106, 107], "classif": 4, "If": [4, 7, 8, 10, 11, 12, 40, 47, 48, 49, 52, 53, 54, 60, 63, 65, 66, 67, 73, 77, 82, 89, 90, 100, 101, 106], "provid": [4, 6, 25, 27, 29, 30, 31, 32, 33, 34, 60, 64, 65, 66, 67, 73, 77, 87, 91, 92, 103, 105, 107], "option": [4, 8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 22, 24, 25, 28, 39, 40, 46, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 77, 82, 84, 87, 92, 94, 96, 100, 101, 102, 103], "argument": [4, 9, 50, 54, 60, 65, 66, 67, 77, 100, 101, 107], "should": [4, 18, 39, 47, 56, 58, 60, 65, 66, 67, 75, 77, 89, 90, 91, 92, 94, 95, 100, 101, 102, 107], "1d": 4, "assign": [4, 77], "each": [4, 6, 7, 8, 10, 11, 12, 24, 25, 47, 49, 52, 60, 65, 66, 67, 77, 89, 90, 91, 100, 101, 102, 103, 107], "particularli": [4, 107], "unbalanc": 4, "expect": [4, 10, 16, 49, 77, 89, 90, 107], "contain": [4, 5, 6, 8, 23, 47, 49, 65, 66, 67, 77, 84, 86, 89, 90, 91, 96, 99, 100, 106, 107], "unnorm": [4, 47], "do": [4, 7, 8, 48, 49, 65, 66, 67, 77, 100, 101, 107], "posit": [4, 19, 77, 101], "sum": [4, 16, 18, 47, 63, 77, 106], "gener": [4, 22, 23, 24, 25, 42, 65, 66, 67, 76, 77, 86, 87, 92, 102, 103, 107], "ha": [4, 7, 10, 16, 28, 47, 54, 60, 65, 66, 67, 73, 77, 91, 92, 100, 106, 107], "size": [4, 8, 10, 11, 12, 13, 14, 19, 20, 22, 25, 47, 49, 60, 77, 100, 103], "unbatch": 4, "minibatch": [4, 47], "d_1": [4, 47], "d_2": [4, 47], "d_k": [4, 47], "k": [4, 11, 12, 47, 54, 73, 77, 106], "geq": [4, 47], "dimension": [4, 8, 10, 18, 47, 54, 77], "case": [4, 9, 11, 12, 13, 14, 17, 18, 19, 20, 25, 39, 46, 47, 50, 52, 53, 54, 61, 63, 65, 66, 67, 73, 77, 87, 92, 95, 97, 103, 106, 107], "last": [4, 6, 10, 11, 12, 51, 59, 73, 77, 85], "being": [4, 8, 47, 58, 65, 66, 67, 77, 85, 94, 100, 101], "higher": [4, 60, 77], "dimens": [4, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 51, 52, 53, 54, 59, 63, 77, 100], "per": [4, 10, 16, 47, 65, 66, 67, 71, 77, 100], "pixel": 4, "2d": [4, 77], "imag": [4, 10, 77], "criterion": 4, "either": [4, 47, 73, 77, 100], "indic": [4, 8, 24, 25, 47, 49, 60, 65, 66, 67, 77, 100, 103], "rang": [4, 18, 25, 63, 77, 85, 89, 90, 92, 101, 103, 107], "where": [4, 8, 9, 10, 11, 12, 17, 18, 47, 49, 50, 52, 53, 54, 60, 61, 62, 65, 66, 67, 77, 89, 91, 92, 101, 107], "number": [4, 9, 11, 12, 13, 17, 18, 19, 20, 22, 33, 41, 42, 46, 47, 49, 51, 52, 53, 59, 77, 107], "specifi": [4, 8, 47, 49, 60, 63, 65, 66, 67, 73, 77, 87, 91, 100, 107], "accept": [4, 6, 77, 100], "necessarili": [4, 77], "unreduc": 4, "e": [4, 8, 10, 49, 54, 60, 73, 77, 85, 86, 87, 91, 100], "describ": [4, 7, 10, 16, 40, 47, 77, 82], "ell": 4, "x": [4, 9, 10, 16, 17, 24, 38, 42, 50, 61, 62, 76, 77, 92, 100, 102, 107], "y": [4, 10, 16, 52, 53, 77, 100], "l": [4, 60], "l_1": 4, "dot": [4, 54, 60, 77, 104], "l_n": 4, "top": [4, 92], "quad": 4, "w_": 4, "y_n": 4, "log": [4, 77, 106, 107], "frac": [4, 7, 10, 11, 12, 16, 18, 60, 63, 77], "exp": [4, 18, 63, 77], "x_": [4, 18, 63], "n": [4, 8, 10, 18, 47, 54, 60, 77, 101, 106], "sum_": 4, "cdot": [4, 106], "mathbb": 4, "text": [4, 8, 9, 10, 11, 12, 17, 18, 22, 46, 47, 50, 61, 62, 63, 65, 66, 67, 77], "ignor": [4, 47, 76, 77, 107], "_index": 4, "w": [4, 8, 10, 49, 77, 100], "span": [4, 77], "default": [4, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 24, 25, 39, 40, 46, 47, 48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 63, 65, 66, 67, 71, 73, 77, 82, 87, 92, 94, 95, 96, 97, 100, 102, 103, 107], "begin": [4, 47, 65, 66, 67, 73, 77, 107], "end": [4, 5, 6, 19, 47, 65, 66, 67, 73, 77, 107], "logsoftmax": 4, "nllloss": 4, "probabl": [4, 7, 13, 19, 20, 47, 48, 60, 77], "label": [4, 22, 25, 103], "beyond": [4, 65, 66, 107], "singl": [4, 6, 10, 57, 65, 66, 67, 77, 85, 106], "item": [4, 73, 77], "blend": 4, "smooth": [4, 47], "etc": 4, "w_c": 4, "y_": 4, "perform": [4, 6, 60, 63, 65, 66, 67, 71, 77, 100, 107], "better": [4, 60, 92, 106], "allow": [4, 6, 58, 77, 85, 106, 107], "optim": [4, 60, 85, 94, 95, 104], "onli": [4, 10, 25, 31, 32, 33, 34, 42, 47, 54, 60, 68, 76, 77, 80, 81, 85, 86, 89, 90, 91, 94, 100, 101, 103, 105, 107], "too": [4, 60, 77, 89, 90], "restrict": [4, 54], "paramet": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 46, 47, 48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 71, 80, 81, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 101, 102, 103, 104], "manual": [4, 6, 47, 65, 66, 67, 77, 85, 92, 107], "rescal": [4, 18, 47], "given": [4, 5, 6, 8, 19, 22, 40, 42, 44, 47, 49, 60, 64, 73, 76, 77, 82, 86, 87, 88, 89, 90, 96, 100, 102, 107], "point": [4, 24, 25, 42, 49, 60, 77, 86, 87, 91, 103, 107], "dtype": [4, 8, 10, 11, 12, 47, 52, 53, 54, 60, 63, 77, 100, 101], "deprec": [4, 47, 77], "By": [4, 47, 77, 92, 95, 97, 107], "averag": [4, 40, 47, 65, 66, 82], "over": [4, 6, 10, 16, 47, 56, 57, 58, 60, 65, 66, 67, 77, 92, 100, 106, 107], "element": [4, 7, 10, 16, 18, 39, 47, 48, 50, 55, 60, 63, 73, 77, 106], "batch": [4, 8, 10, 22, 25, 47, 49, 54, 65, 66, 67, 77, 103], "multipl": [4, 47, 54, 77, 80, 81, 87, 107], "sampl": [4, 7, 8, 11, 12, 22, 47, 48, 49, 77], "field": [4, 47, 75, 77], "fals": [4, 7, 8, 10, 11, 12, 16, 17, 24, 38, 47, 48, 49, 60, 61, 65, 66, 67, 71, 73, 77, 100, 102], "instead": [4, 47, 65, 66, 67, 77], "true": [4, 7, 8, 10, 11, 12, 16, 24, 25, 47, 48, 49, 60, 65, 66, 67, 71, 73, 77, 89, 90, 96, 100, 101, 102, 103], "valu": [4, 6, 8, 10, 11, 12, 16, 18, 24, 25, 37, 42, 47, 56, 57, 58, 60, 65, 66, 67, 71, 73, 77, 86, 89, 90, 92, 100, 101, 103, 106, 107], "doe": [4, 5, 6, 8, 47, 54, 60, 65, 66, 67, 71, 73, 77, 85, 94, 107], "contribut": [4, 8, 40, 47, 49, 56, 57, 58, 82], "gradient": [4, 8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 31, 32, 33, 34, 39, 46, 47, 49, 50, 52, 53, 54, 58, 61, 63, 65, 66, 67, 77, 84, 89, 92, 100, 106, 107], "non": [4, 8, 13, 20, 41, 42, 47, 49, 54, 58, 60, 71, 77, 87, 100, 107], "applic": [4, 47, 77, 100], "observ": [4, 47, 77], "depend": [4, 24, 47, 54, 60, 71, 77, 92, 101, 106, 107], "return": [4, 6, 15, 18, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 44, 47, 54, 56, 58, 60, 62, 63, 65, 66, 67, 71, 73, 74, 76, 77, 80, 81, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 100, 101, 102, 103, 107], "taken": [4, 107], "process": [4, 8, 47, 65, 66, 67, 86, 92, 101], "meantim": [4, 47], "two": [4, 31, 47, 54, 58, 60, 65, 66, 67, 77, 85, 100, 107], "arg": [4, 6, 25, 39, 47, 65, 66, 67, 75, 77, 96, 100, 101, 103], "overrid": [4, 47, 100], "amount": [4, 47], "becom": [4, 6, 47, 77], "mixtur": [4, 47], "origin": [4, 17, 47, 61, 77, 81, 101, 106], "ground": [4, 47], "truth": [4, 47], "uniform": [4, 47, 77], "distribut": [4, 7, 9, 47, 48, 50, 77, 105, 106, 107], "rethink": [4, 47], "incept": [4, 47], "vision": [4, 47], "multipli": [4, 9, 13, 17, 18, 47, 50, 54, 60, 61, 62, 63, 77, 106, 107], "chang": [4, 5, 6, 8, 9, 13, 17, 18, 24, 47, 50, 60, 61, 62, 63, 65, 66, 67, 77, 85, 87, 106, 107], "shape": [4, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 46, 47, 49, 50, 52, 53, 60, 61, 62, 63, 77, 100], "nonlinear": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 106], "typic": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 95, 97, 107], "high": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77], "correspond": [4, 8, 9, 13, 17, 18, 25, 28, 44, 47, 49, 50, 60, 61, 62, 63, 65, 66, 67, 73, 77, 92, 100, 103], "sharper": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "temperatur": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 106], "flatter": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "same": [4, 6, 7, 9, 10, 11, 12, 17, 18, 40, 47, 54, 60, 62, 77, 82, 87, 89, 90, 100, 106, 107], "otherwis": [4, 65, 66, 67, 73, 77, 100], "scalar": [4, 10, 54, 77, 80, 81, 99, 107], "align": [4, 47, 60, 65, 66, 67, 77], "randn": [4, 7, 9, 10, 11, 12, 17, 18, 46, 47, 77, 102, 107], "3": [4, 8, 9, 10, 18, 46, 47, 49, 50, 65, 66, 67, 77, 107], "5": [4, 6, 7, 8, 9, 10, 16, 37, 46, 47, 48, 49, 50, 52, 60, 65, 66, 67, 77, 104, 107], "requires_grad": [4, 8, 47, 77, 100], "empti": [4, 77], "long": [4, 77], "random_": [4, 77], "backward": [4, 11, 12, 23, 24, 25, 42, 47, 54, 58, 65, 66, 67, 77, 79, 80, 84, 86, 87, 89, 90, 91, 92, 99, 100, 101, 102, 103, 106, 107], "softmax": [4, 13, 19, 20, 47, 60, 77, 104, 106, 107], "dim": [4, 18, 38, 47, 60, 63, 77], "iter": [5, 65, 66, 67, 71, 73, 77, 88, 94], "modulelist": [5, 6], "automat": [5, 6, 24, 25, 60, 77, 91, 92, 95, 100, 102, 103, 107], "configur": [5, 6], "sake": [5, 6, 83, 91], "track": [5, 6, 24, 25, 77, 86, 91, 103], "caus": [5, 6, 77, 107], "after": [5, 6, 28, 54, 65, 66, 67, 77, 92, 100], "initi": [5, 6, 8, 10, 11, 12, 16, 40, 49, 65, 66, 67, 77, 82, 107], "construct": [5, 6, 8, 15, 49, 74, 77], "like": [5, 6, 13, 19, 20, 65, 66, 67, 73, 77, 100, 107], "regular": [5, 7, 66, 71, 92], "list": [5, 6, 8, 10, 49, 65, 66, 67, 77, 94, 101, 105], "properli": [5, 106], "regist": [5, 6, 65, 66, 67, 77], "visibl": 5, "method": [5, 6, 25, 65, 71, 73, 77, 85, 86, 87, 89, 90, 91, 94, 100, 101, 102, 103, 104, 107], "add": [5, 56, 65, 66, 67, 77, 92, 104, 105, 107], "append": [5, 6, 19, 54, 77, 107], "extend": [5, 72, 100], "self": [5, 13, 17, 20, 60, 61, 65, 66, 67, 76, 77, 101, 102, 104, 107], "insert": [5, 73, 89, 90], "befor": [5, 63, 65, 66, 67, 77, 89, 90, 100, 101, 106, 107], "sequenti": [6, 19], "order": [6, 58, 60, 65, 66, 67, 73, 77, 101, 107], "pass": [6, 8, 23, 24, 25, 42, 58, 60, 65, 66, 67, 71, 76, 77, 79, 80, 81, 84, 89, 90, 91, 92, 95, 99, 100, 101, 102, 103, 106, 107], "constructor": 6, "altern": [6, 40, 64, 82, 107], "ordereddict": 6, "forward": [6, 7, 23, 24, 25, 42, 60, 77, 79, 80, 81, 84, 85, 89, 90, 91, 92, 94, 95, 99, 100, 101, 102, 103, 106, 107], "chain": [6, 77, 91], "final": [6, 12, 53, 54, 91, 92], "call": [6, 7, 23, 24, 25, 60, 65, 66, 67, 77, 85, 86, 87, 89, 90, 91, 92, 94, 95, 96, 97, 100, 101, 102, 103], "sequenc": [6, 13, 20, 22, 25, 51, 77, 103, 106], "treat": [6, 10, 18, 77], "whole": [6, 106], "store": [6, 8, 77, 101], "submodul": 6, "what": [6, 60, 65, 66, 67, 86, 100, 104, 106], "": [6, 14, 31, 32, 33, 34, 52, 53, 54, 60, 65, 66, 67, 73, 77, 91, 100, 101, 102, 106, 107], "exactli": [6, 107], "sound": 6, "On": [6, 11, 12, 54, 65, 66, 67], "other": [6, 21, 46, 54, 60, 65, 66, 67, 77, 91, 105, 107], "hand": 6, "layer": [6, 10, 11, 12, 13, 14, 16, 19, 20, 40, 51, 58, 65, 66, 67, 82, 85, 92, 107], "connect": [6, 20, 56, 57, 58, 92, 106, 107], "cascad": 6, "wai": [6, 77, 85, 92, 100, 107], "creat": [6, 8, 71, 73, 77], "small": [6, 106, 107], "conv2d": 6, "20": [6, 7, 10, 11, 12, 46, 77], "relu": [6, 102], "64": [6, 60, 107], "second": [6, 8, 54, 65, 66, 100], "abov": [6, 60, 77, 106, 107], "conv1": 6, "relu1": 6, "conv2": 6, "relu2": 6, "p": [7, 8, 15, 48, 49, 54, 60, 74, 77], "inplac": [7, 17, 48, 61, 65, 66, 67, 100, 102], "zero": [7, 8, 10, 15, 48, 49, 60, 65, 66, 67, 74, 77, 92, 100], "chosen": [7, 31, 32, 33, 34, 60], "independ": [7, 71, 77, 92], "bernoulli": [7, 48, 77], "channel": [7, 10, 77], "everi": [7, 18, 40, 77, 82, 100, 101], "proven": 7, "effect": [7, 77, 89, 90, 106, 107], "techniqu": [7, 107], "prevent": [7, 63, 77, 100], "co": [7, 77], "adapt": [7, 107], "neuron": 7, "improv": [7, 60, 65, 66, 67], "neural": [7, 17, 61], "network": [7, 12, 17, 53, 61, 65, 66, 67], "detector": 7, "_": [7, 60, 65, 66, 67, 100], "furthermor": [7, 77], "factor": [7, 14, 31, 32, 33, 34, 58, 60, 67, 68, 69, 70, 71, 79, 80, 81, 92, 106, 107], "dure": [7, 8, 24, 25, 49, 60, 77, 100, 103, 107], "evalu": [7, 10, 60, 101, 107], "simpli": [7, 107], "ident": [7, 107], "oper": [7, 10, 16, 24, 25, 26, 48, 52, 53, 54, 60, 63, 77, 79, 87, 92, 100, 102, 103, 105, 106, 107], "place": [7, 48, 49, 77, 100, 107], "16": [7, 77, 87, 107], "num_embed": 8, "embedding_dim": [8, 10, 49], "padding_idx": [8, 49], "max_norm": [8, 49], "norm_typ": [8, 49], "scale_grad_by_freq": [8, 49], "spars": [8, 18, 49, 52, 53, 54, 77], "_weight": 8, "_freez": 8, "devic": [8, 10, 11, 12, 52, 53, 54, 60, 77, 101], "lookup": [8, 49], "tabl": [8, 49], "look": [8, 49, 54, 65, 66, 67, 77, 101, 106, 107], "up": [8, 49, 77, 100, 101, 106, 107], "fix": [8, 49, 77, 106, 107], "dictionari": [8, 65, 66, 67, 73, 77, 92], "often": [8, 49, 77, 107], "word": [8, 49], "retriev": [8, 28, 49, 77, 100, 101], "them": [8, 18, 63, 77, 92, 100, 101, 105], "vector": [8, 49, 54, 77], "entri": [8, 49, 65, 66, 67, 77], "therefor": [8, 49, 77], "updat": [8, 49, 71, 73, 77, 107], "remain": [8, 49], "pad": [8, 13, 20, 49, 77], "newli": 8, "anoth": [8, 77], "norm": [8, 11, 12, 15, 49, 74, 77], "larger": [8, 49, 56, 57, 58, 77, 92, 107], "than": [8, 49, 58, 60, 65, 66, 67, 77, 85, 91, 92, 101], "renorm": [8, 49, 77], "invers": [8, 49, 77, 106], "frequenc": [8, 49], "mini": [8, 10, 49], "matrix": [8, 49, 54, 60, 77], "regard": [8, 49, 65, 66, 107], "learnabl": [8, 10, 11, 12, 16], "mathcal": [8, 11, 12], "type": [8, 11, 12, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 46, 49, 56, 58, 60, 62, 63, 65, 66, 67, 71, 76, 77, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 101, 102, 103], "inttensor": [8, 77], "longtensor": [8, 49, 77], "arbitrari": [8, 23, 24, 25, 49, 60, 77, 92, 99, 103, 107], "extract": [8, 49, 101], "h": [8, 10, 77], "_dim": 8, "10": [8, 10, 15, 37, 46, 49, 74, 77, 102, 107], "4": [8, 14, 46, 49, 77, 100, 102, 107], "9": [8, 49, 65, 66, 67, 77], "xdoctest": [8, 49, 67, 100], "ignore_w": [8, 49], "determinist": [8, 49, 60, 77], "0251": 8, "6902": 8, "7172": 8, "6431": 8, "0748": 8, "6969": 8, "4970": 8, "3448": 8, "9685": 8, "3677": 8, "7265": 8, "1685": 8, "4362": 8, "4004": 8, "9400": 8, "9124": 8, "3616": 8, "1151": 8, "0000": [8, 49, 77], "1535": 8, "0309": 8, "9315": 8, "1655": 8, "9897": 8, "0635": 8, "7895": 8, "7089": 8, "0364": 8, "6778": 8, "5803": 8, "2678": 8, "no_grad": [8, 65, 66, 67, 77], "ones": [8, 10, 16, 49, 60, 77], "classmethod": 8, "from_pretrain": 8, "freez": 8, "instanc": [8, 10, 65, 66, 67, 77, 91], "floattensor": [8, 77], "get": [8, 71, 73, 77, 105, 107], "learn": [8, 10, 11, 12, 17, 61, 65, 66, 67, 71, 106, 107], "pretrain": 8, "6": [8, 46, 77, 102], "1000": [8, 37], "3000": 8, "constraint": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 101, 104, 107], "to_output_scal": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "approxim": [9, 17, 50, 61, 107], "gaussian": [9, 17, 50, 61], "error": [9, 17, 24, 25, 50, 55, 60, 61, 77, 100, 101, 103], "linear": [9, 12, 17, 50, 53, 61, 62, 77, 89, 90, 102, 104, 107], "phi": [9, 50], "cumul": [9, 50], "tanh": [9, 50, 77], "estim": [9, 10, 50], "sqrt": [9, 10, 11, 12, 16, 38, 50, 60, 65, 66, 77], "pi": [9, 50, 77], "044715": [9, 50], "algorithm": [9, 60, 65, 66], "name": [9, 11, 12, 13, 14, 17, 18, 19, 20, 22, 25, 28, 39, 46, 50, 52, 53, 54, 61, 63, 77, 101, 103], "In": [9, 11, 12, 13, 14, 17, 18, 19, 20, 24, 39, 46, 50, 52, 53, 54, 58, 60, 61, 63, 73, 77, 92, 95, 100, 106, 107], "must": [9, 11, 12, 13, 14, 17, 18, 19, 20, 28, 37, 39, 46, 50, 52, 53, 54, 60, 61, 63, 77, 85, 86, 87, 100, 106, 107], "one": [9, 11, 12, 13, 14, 17, 18, 19, 20, 28, 39, 46, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 77, 85, 100, 107], "gmean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104, 107], "hmean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "amean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "to_grad_input_scal": [9, 11, 12, 14, 17, 18, 39, 50, 52, 53, 61, 63, 104], "normalized_shap": [10, 16, 51, 59], "ep": [10, 16, 38, 51, 59, 65, 66], "1e": [10, 16, 51, 59, 65, 66, 67, 77, 107], "05": [10, 16, 51, 59, 77, 87, 107], "elementwise_affin": [10, 16], "bia": [10, 11, 12, 15, 51, 52, 53, 60, 74, 102, 107], "normal": [10, 16, 42, 51, 59, 77, 106, 107], "mathrm": [10, 38, 65, 66, 106], "var": [10, 77], "epsilon": [10, 16, 38, 65, 66], "gamma": [10, 16, 65, 66, 67], "deviat": [10, 77, 99, 100, 101, 102, 107], "calcul": [10, 40, 68, 69, 70, 82], "d": [10, 52, 53, 73, 77, 100, 102, 107], "affin": [10, 77], "via": [10, 77, 86, 87, 91, 92, 101, 107], "bias": [10, 92], "unbias": 10, "unlik": [10, 24, 25, 77, 103, 107], "entir": [10, 77, 107], "plane": 10, "statist": 10, "data": [10, 11, 12, 15, 22, 52, 53, 63, 74, 77, 84, 101], "both": [10, 40, 54, 60, 77, 82, 100, 102, 107], "mode": [10, 60, 77, 100], "_shape": 10, "ldot": [10, 65, 66, 67, 77], "integ": [10, 46, 77, 100], "singleton": [10, 77], "specif": [10, 60, 65, 66, 67, 77, 89, 90, 101], "denomin": [10, 16, 65, 66], "numer": [10, 16, 24, 25, 49, 65, 66, 83, 89, 90, 103, 107], "stabil": [10, 16, 65, 66], "boolean": [10, 16, 60, 77, 100], "addit": [10, 11, 12, 18, 46, 52, 53, 65, 66, 67, 77, 91, 107], "nlp": 10, "sentence_length": 10, "embed": [10, 19, 40, 82, 104], "layer_norm": [10, 104], "three": [10, 32, 34, 60, 106], "spatial": 10, "shown": [10, 24, 107], "in_featur": [11, 12], "out_featur": [11, 12], "weight_mup_typ": [11, 12], "liter": [11, 12, 15, 74, 77], "incom": [11, 12, 52, 53], "tensorfloat32": [11, 12, 52, 53, 54], "certain": [11, 12, 51, 54, 59, 65, 66, 67, 87, 101], "rocm": [11, 12, 54], "float16": [11, 12, 54, 60, 65, 66, 67, 77], "_featur": [11, 12, 52, 53], "h_": [11, 12], "includ": [11, 12, 52, 53, 77, 105, 106], "30": [11, 12, 77], "128": [11, 12, 60, 77], "print": [11, 12, 102, 107], "appropri": [12, 71, 77], "hidden_s": [13, 14, 19, 20, 102], "head": [13, 19, 20, 106], "is_caus": [13, 20, 60], "dropout_p": [13, 19, 20, 60], "multi": [13, 20], "attent": [13, 19, 20, 22, 40, 60, 82, 104, 107], "warn": [13, 19, 20, 25, 60, 103], "here": [13, 19, 20, 60, 77, 92, 94, 105, 107], "give": [13, 19, 20, 77, 92, 104, 106, 107], "incorrect": [13, 19, 20, 77, 100], "hidden": [13, 14, 19, 20], "causal": [13, 20, 60], "mask": [13, 19, 20, 22, 60, 77], "post": [13, 19, 20, 65, 66, 67, 77, 106], "dropout": [13, 19, 20, 60, 104], "expansion_factor": 14, "swiglu": 14, "intermedi": [14, 65, 66, 67, 91, 101], "increas": [14, 60, 100], "rel": [14, 20, 40, 56, 58, 77, 82, 87, 106], "mup_typ": [15, 71, 74], "mup_scaling_depth": [15, 71, 74], "\u03bcp": [15, 40, 72, 74, 75, 82], "object": [15, 65, 66, 67, 73, 74, 75, 77, 84, 91, 100], "annot": [15, 23, 74, 102], "parameterdata": [15, 68, 69, 70, 71, 74, 78], "protocol": [15, 74, 75, 78], "assert": [15, 60, 74], "tupl": [16, 22, 28, 38, 43, 44, 52, 58, 59, 65, 66, 67, 77, 96, 100, 101, 102, 107], "rm": [16, 59, 104], "normalis": [16, 58], "trail": 16, "root": 16, "squar": [16, 55, 60, 65, 66, 77], "sigmoid": [17, 61, 62, 77], "known": [17, 61, 105], "swish": [17, 61], "sigma": [17, 61, 62, 77], "logist": [17, 61, 62], "gelu": [17, 61, 92, 104, 107], "wa": [17, 61, 77, 101], "coin": [17, 61], "reinforc": [17, 61], "gate": [17, 61, 62], "experi": [17, 61, 107], "later": [17, 61, 77], "lie": [18, 63], "li": 18, "adjust": 18, "accordingli": 18, "defin": [18, 28, 60, 63, 64, 65, 66, 67, 76, 77, 92, 100, 107], "x_i": [18, 63], "sum_j": [18, 63], "x_j": [18, 63], "unspecifi": [18, 65, 66, 67, 77], "inf": [18, 60, 77], "along": [18, 22, 63, 65, 66, 67, 77, 107], "slice": [18, 63, 77, 87, 101], "vocab_s": 19, "residual_sc": 19, "callabl": [19, 39, 40, 57, 65, 66, 67, 71, 77, 82, 88, 92, 94, 95, 96, 101, 102, 107], "transformer_residual_scaling_rul": [19, 104], "local": [19, 102], "_tau": 19, "decod": 19, "just": [19, 22, 65, 66, 67, 77, 100, 101, 106, 107], "lack": [19, 73], "kei": [19, 60, 73, 77, 84, 91, 92, 104], "usag": [19, 35, 77], "infer": [19, 77], "token": [19, 22, 25, 103], "vocabulari": 19, "residu": [19, 20, 40, 56, 57, 58, 82, 92, 106, 107], "scheme": [19, 40, 77, 82], "control": [19, 60, 64, 107], "trunk": 19, "core": [19, 104, 106], "mhsa_tau": 20, "mlp_tau": 20, "prenorm": 20, "arxiv": 20, "org": 20, "ab": [20, 77], "2002": 20, "04745": 20, "branch": [20, 56, 57, 58, 107], "skip": [20, 56, 57, 58, 65, 66, 67, 100, 106, 107], "mlp": [20, 40, 82, 102, 104, 107], "tool": [21, 107], "analys": [21, 102, 107], "metric": [21, 23, 24, 91, 104], "within": [21, 77, 85, 94, 101, 107], "pretrainedtokenizerbas": [22, 25, 103], "batch_siz": [22, 25, 100, 103, 107], "seq_len": [22, 25, 103], "dataset_path": [22, 25, 103], "wikitext": [22, 25, 103], "dataset_nam": [22, 25, 103], "103": [22, 25, 103], "v1": [22, 25, 103], "shuffle_buffer_s": 22, "10000": 22, "seed": 22, "1472": 22, "dataset": [22, 25, 103], "shift": [22, 77], "length": [22, 25, 77, 103, 106], "huggingfac": [22, 25, 103], "visualis": [22, 24, 91, 104], "random": [22, 60, 77, 100], "chunk": [22, 77], "determin": [22, 54, 107], "10_000": 22, "shuffl": 22, "input_idx": 22, "attn_mask": [22, 25, 60, 103], "g": [23, 24, 40, 62, 65, 66, 77, 82, 86, 87, 91, 100, 106], "graph": [23, 24, 25, 65, 66, 77, 86, 87, 88, 91, 92, 94, 95, 96, 97, 100, 101, 103], "datafram": 23, "convert": [23, 43, 77, 97, 107], "fx": [23, 24, 85, 86, 87, 88, 91, 94, 95, 97, 101, 102, 107], "panda": 23, "indend": 23, "been": [23, 24, 28, 77, 86, 87, 92, 100, 107], "track_scal": [23, 24, 25, 86, 87, 103, 104], "possibli": [23, 24], "scales_graph": [23, 24, 25, 86, 87, 91, 103], "result": [23, 24, 56, 58, 77, 91, 101], "inform": [23, 60, 77, 84, 87, 91], "intern": [23, 89, 90, 91], "plot": [23, 25, 91, 103, 104], "pd": 23, "titl": 24, "mean_ab": [24, 84], "prune_same_scal": 24, "show_arrow": 24, "show_error_bar": 24, "show_zero_tensor": 24, "xmin": 24, "xmax": 24, "ax": [24, 25, 103], "matplotlib": [24, 25, 103], "intend": [24, 65, 66, 67, 87, 91, 99, 100], "inpt": [24, 86, 87, 91], "prune": [24, 25, 86, 87, 88, 103], "deem": [24, 25, 87, 103], "perspect": [24, 25, 103, 107], "faint": [24, 25, 103], "colour": [24, 25, 103], "horizont": [24, 25, 103], "line": [24, 25, 103, 107], "row": [24, 25, 49, 77, 103], "repres": [24, 25, 77, 84, 88, 91, 92, 94, 96, 99, 100, 101, 102, 103, 107], "bar": [24, 25, 103], "maximum": [24, 25, 42, 49, 77, 103], "minimum": [24, 25, 42, 77, 103], "seen": [24, 25, 77, 103, 106, 107], "axi": [24, 77], "abs_mean": [24, 84], "std": [24, 77, 84], "abs_max": [24, 84], "abs_min": [24, 84], "numel": [24, 77, 84], "reshap": [24, 77, 87], "clearer": 24, "arrow": 24, "denot": [24, 77, 107], "max": [24, 65, 66, 77, 107], "min": [24, 77, 107], "displai": 24, "plot_kwarg": [25, 103], "experiment": [25, 77, 89, 90, 92, 103], "conveni": [25, 94, 103], "combin": [25, 52, 53, 54, 56, 57, 77, 85, 103, 107], "example_batch": [25, 103, 104], "wide": [25, 92, 103], "interfac": [25, 92, 94, 103], "futur": [25, 77, 85, 100, 103], "now": [25, 77, 91, 95, 100, 103, 107], "templat": [25, 103], "tracked_model": [25, 103], "keyword": [25, 60, 77, 101, 103], "alias": 26, "arithmet": 27, "group": [27, 28, 29, 30, 64, 65, 66, 67, 71, 77], "constrain": [27, 29, 30, 92, 107], "constraint_nam": 28, "rais": [28, 60, 73, 77, 100], "valueerror": 28, "geometr": [29, 77, 107], "harmon": 30, "xavier": [30, 107], "glorot": [30, 107], "output_scal": [31, 32, 33, 34, 39, 107], "grad_input_scal": [31, 33, 39, 107], "select": [31, 32, 33, 34, 60, 77, 106], "op": [31, 32, 33, 34, 65, 66, 67, 77, 85, 89, 90, 100, 101, 105, 107], "equal": [31, 32, 33, 34, 40, 49, 56, 57, 58, 77, 82, 107], "left_grad_scal": [32, 34], "right_grad_scal": [32, 34], "left": [32, 34, 54, 60, 77, 107], "right": [32, 34, 54], "grad": [33, 39, 52, 65, 66, 67, 77, 80, 81, 100, 107], "compon": 35, "advanc": 35, "alpha": [37, 46, 77], "lower": [37, 60, 65, 66, 67], "upper": [37, 60, 77], "interpol": [37, 77], "logarithm": 37, "space": 37, "constant": 37, "ratio": [37, 40, 56, 57, 58, 77, 82, 107], "limit": [37, 60, 91, 104], "keepdim": [38, 77], "wise": [39, 50, 55], "take": [39, 60, 77, 92, 107], "kwarg": [39, 65, 66, 67, 75, 77, 96, 100, 101], "residual_mult": [40, 82], "residual_attn_ratio": [40, 82], "tau": [40, 56, 57, 58, 67, 82, 107], "rule": [40, 60, 64, 65, 66, 67, 77, 82, 106], "stack": [40, 82, 107], "start": [40, 77, 82, 101, 107], "ensur": [40, 60, 77, 82, 100, 101, 106, 107], "varianc": [40, 82, 106], "attn": [40, 82], "hyperparamet": [40, 82, 107], "total": [40, 77, 82], "appendix": [40, 82], "ffn": [40, 82], "fn": [40, 57, 82, 95], "simul": [41, 77, 89, 90], "exponent_bit": [42, 43, 44], "mantissa_bit": [42, 43, 44], "round": [42, 77, 89, 90], "stochast": [42, 65, 89, 90], "srbit": 42, "represent": [42, 77, 89, 90], "properti": [42, 77, 106], "bit": [42, 77], "max_absolute_valu": 42, "absolut": [42, 77], "min_absolute_norm": 42, "min_absolute_subnorm": 42, "subnorm": 42, "quantis": [42, 89, 90], "differenti": [42, 65, 66, 67, 77, 100], "quantise_bwd": 42, "quantise_fwd": 42, "fpformat": [43, 44, 89, 104], "_i": 46, "broadcast": [46, 54, 60, 77], "promot": 46, "complex": [46, 77, 101], "to_left_grad_scal": [46, 54, 104], "to_right_grad_scal": [46, 54, 104], "0202": 46, "0985": 46, "3506": 46, "6056": 46, "21": 46, "19": 46, "3944": 46, "b": [46, 52, 53, 67, 77, 100], "9732": 46, "3497": 46, "6245": 46, "4022": 46, "3743": 46, "7724": 46, "5811": 46, "8017": 46, "7695": 46, "3930": 46, "3672": 46, "1450": 46, "18": [46, 77], "6971": 46, "0736": 46, "17": 46, "0994": 46, "3216": 46, "7845": 46, "1610": 46, "1868": 46, "4090": 46, "8": [46, 60, 65, 66, 77, 107], "9902": 46, "3667": 46, "7": [46, 77], "3925": 46, "6147": 46, "crossentropyloss": [47, 104, 107], "predict": [47, 77], "section": [47, 77, 104, 107], "divid": [47, 77], "randint": [47, 77], "int64": [47, 77], "dictionaryand": 49, "analyt": 49, "respect": [49, 65, 66, 67, 77, 100], "column": [49, 77], "modifi": [49, 65, 66, 67, 77, 100, 106], "v": [49, 65, 66, 67, 73, 77, 106], "embedding_matrix": 49, "rand": [49, 60, 77], "8490": 49, "9625": 49, "6753": 49, "9666": 49, "7761": 49, "6108": 49, "6246": 49, "9751": 49, "3618": 49, "4161": 49, "2419": 49, "7383": 49, "0237": 49, "7794": 49, "0528": 49, "3385": 49, "8612": 49, "1867": 49, "zero_": [49, 77], "5609": 49, "5384": 49, "8720": 49, "6262": 49, "2438": 49, "7471": 49, "layernorm": [51, 104], "scale_pow": 52, "xa": [52, 53], "layout": [52, 53, 54, 77, 101], "autograd": [52, 53, 54, 65, 66, 67, 77, 100], "miss": [52, 53, 54, 105], "pleas": [52, 53, 54, 60, 65, 66, 77, 100, 107], "open": [52, 53, 54, 106], "request": [52, 53, 54, 65, 66, 67, 105], "power": 52, "product": [54, 60, 77, 104], "behavior": [54, 65, 66, 67, 77, 100], "prepend": [54, 65, 66, 67, 77], "purpos": [54, 77, 89, 90, 107], "remov": [54, 65, 66, 67, 73, 77, 87, 96, 107], "least": [54, 77], "thu": [54, 107], "j": [54, 77], "logic": 54, "valid": [54, 77, 92, 107], "even": [54, 77, 107], "though": [54, 58, 77, 85, 94, 105, 107], "particular": [54, 77, 92, 106, 107], "mm": [54, 77], "measur": [55, 107], "mseloss": 55, "togeth": [56, 107], "conjunct": [56, 58, 91, 95, 107], "residual_split": [56, 57, 104, 107], "come": [56, 77, 85], "favor": [56, 57, 58], "maintain": 57, "residual_add": [57, 58, 104, 107], "split": [58, 77], "prior": [58, 60, 100], "necessari": [58, 77, 86, 100, 107], "delai": [58, 77, 107], "still": [58, 77, 85, 92, 100, 105], "benefici": 58, "behav": [58, 65, 66, 67, 77], "rmsnorm": [59, 104], "queri": 60, "whatev": [60, 77], "underli": [60, 77, 85], "avail": [60, 91, 107], "flash": 60, "greater": [60, 77], "identifi": [60, 77, 86, 92], "versu": 60, "effici": [60, 77, 100, 107], "def": [60, 76, 77, 92, 100, 102, 107], "scale_factor": 60, "math": [60, 102], "els": [60, 65, 66, 67, 73, 77], "attn_bia": 60, "temp_mask": 60, "tril": [60, 77], "diagon": [60, 77], "masked_fill_": [60, 77], "logical_not": [60, 77], "attn_weight": 60, "transpos": [60, 77], "subject": [60, 77], "alwai": [60, 77, 91, 100], "accord": [60, 77, 92], "disabl": [60, 71], "sure": 60, "mymodel": 60, "__init__": [60, 102, 107], "super": [60, 102, 107], "There": [60, 77, 100, 107], "flashattent": 60, "faster": [60, 107], "parallel": 60, "partit": 60, "memori": [60, 65, 66, 67, 77, 100, 107], "match": [60, 65, 66, 67, 77], "formul": 60, "kernel": [60, 77, 107], "cuda": [60, 65, 66, 67, 77], "backend": [60, 85, 94, 95], "attempt": [60, 107], "most": [60, 77, 100, 105, 106, 107], "fine": [60, 65, 66, 67, 77], "grain": 60, "context": [60, 65, 66, 67, 77, 100], "manag": [60, 77], "prefer": 60, "mechan": [60, 107], "sdpa_kernel": 60, "enable_flash_sdp": 60, "global": [60, 71, 85], "enable_mem_efficient_sdp": 60, "enable_math_sdp": 60, "fuse": [60, 65, 66, 67, 107], "event": [60, 77], "reason": [60, 91, 105, 106], "cannot": [60, 77], "due": [60, 65, 66, 67, 77], "natur": 60, "float64": [60, 65, 66, 67, 77], "numerical_accuraci": 60, "circumst": 60, "cudnn": 60, "nondeterminist": [60, 77], "undesir": 60, "try": [60, 65, 66, 67, 107], "potenti": [60, 77], "cost": [60, 77], "ev": 60, "part": 60, "score": 60, "triangular": 60, "form": [60, 107], "causalbia": 60, "thrown": [60, 77], "32": [60, 77], "sdp_kernel": 60, "enable_math": 60, "silu": [62, 104], "desir": [63, 77, 107], "cast": [63, 77], "overflow": [63, 107], "mup": [64, 65, 66, 67, 71], "adam": [64, 66, 69, 71, 104], "adamw": [64, 69, 104], "sgd": [64, 70, 77, 104], "scaled_paramet": [64, 104], "finer": 64, "downstream": 64, "lr": [64, 65, 66, 67, 68, 69, 70, 71, 77], "param": [65, 66, 67, 68, 69, 70, 71, 77], "dict": [65, 66, 67, 71, 73, 77, 92, 96, 101], "001": [65, 66, 67], "weight_decai": [65, 66, 67, 71], "independent_weight_decai": [65, 66, 67, 71], "allow_non_unit_scaling_param": [65, 66, 67, 71], "110mm": [65, 66, 67], "4pt": [65, 66, 67], "textbf": [65, 66, 67], "beta_1": [65, 66], "beta_2": [65, 66], "theta_0": [65, 66, 67], "theta": [65, 66, 67], "hspace": [65, 66, 67], "13mm": [65, 66, 67], "lambda": [65, 66, 67, 77], "decai": [65, 66, 67, 71], "textit": [65, 66, 67], "amsgrad": [65, 66], "maxim": [65, 66, 67], "m_0": [65, 66], "leftarrow": [65, 66, 67], "moment": [65, 66], "v_0": [65, 66], "widehat": [65, 66], "ex": [65, 66, 67], "5mm": [65, 66, 67], "10mm": [65, 66, 67], "g_t": [65, 66, 67], "nabla_": [65, 66, 67], "f_t": [65, 66, 67], "theta_": [65, 66, 67], "neq": [65, 67], "m_t": [65, 66], "m_": [65, 66], "v_t": [65, 66], "v_": [65, 66], "2_t": [65, 66], "big": [65, 66], "theta_t": [65, 66, 67], "bf": [65, 66, 67], "refer": [65, 66, 77, 104, 107], "rate": [65, 66, 67, 71, 106], "captur": [65, 66, 95, 97], "coeffici": [65, 66], "999": [65, 66, 102, 107], "term": [65, 66, 77, 107], "l2": [65, 67], "penalti": [65, 67], "whether": [65, 66, 67, 77, 92, 100], "variant": [65, 66], "converg": [65, 66], "foreach": [65, 66, 67], "loop": [65, 66, 67], "sinc": [65, 66, 67, 77], "usual": [65, 66, 67, 77, 107], "significantli": [65, 66, 67, 107], "sizeof": [65, 66, 67], "peak": [65, 66, 67], "tensorlist": [65, 66, 67], "prohibit": [65, 66, 67], "fewer": [65, 66, 67], "through": [65, 66, 67, 77, 100, 101, 105, 106, 107], "switch": [65, 66, 67], "flag": [65, 66, 67], "minim": [65, 66, 67], "safe": [65, 66, 77], "impair": [65, 66, 67], "ungraph": [65, 66], "leav": [65, 66, 67, 77, 107], "occur": [65, 66, 67, 77, 107], "step": [65, 66, 67, 71, 77, 107], "float32": [65, 66, 67, 77], "bfloat16": [65, 66, 67, 77], "add_param_group": [65, 66, 67], "param_group": [65, 66, 67, 71], "tune": [65, 66, 67, 107], "frozen": [65, 66, 67], "made": [65, 66, 67, 100], "trainabl": [65, 66, 67], "progress": [65, 66, 67, 77], "load_state_dict": [65, 66, 67, 77], "state_dict": [65, 66, 67], "load": [65, 66, 67, 77], "state": [65, 66, 67, 77], "register_load_state_dict_post_hook": [65, 66, 67], "hook": [65, 66, 67, 77, 100], "removablehandl": [65, 66, 67], "signatur": [65, 66, 67, 76, 77, 100], "fire": [65, 66, 67], "alreadi": [65, 66, 67, 77], "handl": [65, 66, 67, 77, 100, 101], "util": [65, 66, 67, 77, 104, 107], "removeablehandl": [65, 66, 67], "register_load_state_dict_pre_hook": [65, 66, 67], "shallow": [65, 66, 67, 73], "copi": [65, 66, 67, 73, 77, 92], "new": [65, 66, 67, 73, 77, 85, 89, 90, 91, 94, 95, 96, 104, 105, 107], "register_state_dict_post_hook": [65, 66, 67], "register_state_dict_pre_hook": [65, 66, 67], "register_step_post_hook": [65, 66, 67], "register_step_pre_hook": [65, 66, 67], "new_arg": [65, 66, 67], "new_kwarg": [65, 66, 67], "hold": [65, 66, 67, 77], "Its": [65, 66, 67], "content": [65, 66, 67, 100], "characterist": [65, 66, 67], "itself": [65, 66, 67, 77], "NOT": [65, 66, 67], "map": [65, 66, 67, 77, 101], "metadata": [65, 66, 67, 75], "associ": [65, 66, 67, 91], "zip": [65, 66, 67], "actual": [65, 66, 67, 77], "without": [65, 66, 67, 77, 85, 102, 105, 107], "verif": [65, 66, 67], "might": [65, 66, 67, 77, 107], "momentum_buff": [65, 66, 67], "01": [65, 66, 67, 77, 107], "closur": [65, 66, 67], "reevalu": [65, 66, 67], "zero_grad": [65, 66, 67], "set_to_non": [65, 66, 67], "reset": [65, 66, 67], "footprint": [65, 66, 67], "modestli": [65, 66, 67], "howev": [65, 66, 67, 77, 92, 94, 106, 107], "tri": [65, 66, 67, 77], "access": [65, 66, 67, 77, 91, 100], "attribut": [65, 66, 67, 72, 76, 77, 100, 101], "guarante": [65, 66, 67, 77, 101, 105], "did": [65, 66, 67], "receiv": [65, 66, 67, 107], "altogeth": [65, 66, 67], "decoupl": 66, "mu": [67, 77], "momentum": 67, "dampen": 67, "nesterov": 67, "15mm": 67, "_t": 67, "g_": 67, "formula": [67, 100], "deep": [67, 106, 107], "__": 67, "loss_fn": 67, "lr_scale_func": 71, "adam_lr_scale_func": 71, "paramst": 71, "tag": [71, 75], "lr_scale_func_sgd": [71, 104], "overridden": [71, 100], "fail": [71, 77, 107], "rememb": 73, "clear": [73, 101], "od": 73, "fromkei": 73, "move_to_end": 73, "move": [73, 77, 85, 94], "keyerror": 73, "pop": 73, "popitem": 73, "pair": [73, 77, 99], "lifo": 73, "fifo": 73, "setdefault": 73, "extra": [75, 107], "implicitli": 75, "proto": 76, "meth": 76, "Such": 76, "primarili": 76, "static": [76, 100], "checker": 76, "recogn": 76, "structur": [76, 100, 101], "subtyp": 76, "duck": 76, "func": [76, 100], "pep": 76, "544": 76, "decor": [76, 100], "runtime_check": 76, "act": 76, "simpl": [76, 77], "mind": 76, "runtim": 76, "presenc": 76, "genproto": 76, "conjug": 77, "conj": 77, "matric": 77, "real": 77, "mh": 77, "revers": 77, "permut": 77, "throw": 77, "mt": 77, "arang": 77, "ndim": 77, "abs_": 77, "alia": 77, "absolute_": 77, "aco": 77, "acos_": 77, "acosh": 77, "acosh_": 77, "add_": 77, "addbmm": 77, "batch1": 77, "batch2": 77, "addbmm_": 77, "addcdiv": 77, "tensor1": 77, "tensor2": 77, "addcdiv_": 77, "addcmul": 77, "addcmul_": 77, "addmm": 77, "mat1": 77, "mat2": 77, "addmm_": 77, "addmv": 77, "mat": 77, "vec": 77, "addmv_": 77, "addr": 77, "vec1": 77, "vec2": 77, "addr_": 77, "adjoint": 77, "align_a": 77, "explicit": 77, "align_to": 77, "127": 77, "refine_nam": 77, "img": 77, "scale_channel": 77, "num_channel": 77, "more_img": 77, "video": [77, 104], "agnost": 77, "api": [77, 101, 104, 107], "ellipsi": 77, "expand": [77, 100], "mention": 77, "appear": 77, "string": [77, 102], "unment": 77, "named_tensor": 77, "front": 77, "keep": [77, 107], "rest": 77, "allclos": 77, "rtol": [77, 87], "atol": 77, "08": 77, "equal_nan": 77, "amax": 77, "amin": 77, "aminmax": 77, "angl": 77, "apply_": 77, "cpu": 77, "arcco": 77, "arccos_": 77, "arccosh": 77, "arccosh_": 77, "arcsin": 77, "arcsin_": 77, "arcsinh": 77, "arcsinh_": 77, "arctan": 77, "arctan2": 77, "arctan2_": 77, "atan2_": 77, "arctan_": 77, "arctanh": 77, "arctanh_": 77, "argmax": 77, "argmin": 77, "argsort": 77, "descend": [77, 101], "argwher": 77, "as_strid": 77, "stride": 77, "storage_offset": 77, "as_strided_": 77, "as_strided_scatt": 77, "src": 77, "as_subclass": 77, "cl": 77, "pointer": 77, "stai": [77, 107], "attach": 77, "subclass": [77, 100], "asin": 77, "asin_": 77, "asinh": 77, "asinh_": 77, "atan": 77, "atan2": 77, "atan_": 77, "atanh": 77, "atanh_": 77, "retain_graph": 77, "create_graph": 77, "wrt": 77, "addition": 77, "accumul": 77, "stream": 77, "semant": [77, 101], "leaf": 77, "grad_fn": 77, "strictli": 77, "reli": [77, 92], "pull": 77, "60521": 77, "issuecom": 77, "867061780": 77, "omit": 77, "freed": 77, "nearli": 77, "much": [77, 92, 107], "deriv": [77, 106, 107], "were": [77, 100], "baddbmm": 77, "baddbmm_": 77, "texttt": 77, "bernoulli_": 77, "fill": 77, "locat": 77, "integr": 77, "draw": 77, "binari": 77, "th": 77, "_tensor": 77, "memory_format": [77, 101], "preserve_format": 77, "bincount": 77, "minlength": 77, "bitwise_and": 77, "bitwise_and_": 77, "bitwise_left_shift": 77, "bitwise_left_shift_": 77, "bitwise_not": 77, "bitwise_not_": 77, "bitwise_or": 77, "bitwise_or_": 77, "bitwise_right_shift": 77, "bitwise_right_shift_": 77, "bitwise_xor": 77, "bitwise_xor_": 77, "bmm": 77, "broadcast_to": 77, "byte": 77, "uint8": 77, "cauchy_": 77, "median": 77, "drawn": 77, "cauchi": 77, "dfrac": 77, "cdoubl": 77, "complex128": 77, "ceil": 77, "ceil_": 77, "cfloat": 77, "complex64": 77, "chalf": 77, "complex32": 77, "char": 77, "int8": 77, "choleski": 77, "cholesky_invers": 77, "cholesky_solv": 77, "input2": 77, "clamp": 77, "clamp_": 77, "clip": [77, 107], "clip_": 77, "clone": [77, 100, 104, 107], "coalesc": 77, "uncoalesc": 77, "coo": 77, "col_indic": 77, "csr": 77, "sparse_csr": 77, "nnz": 77, "int32": 77, "mkl": 77, "routin": 77, "avoid": 77, "downcast": 77, "lose": 77, "ey": 77, "to_sparse_csr": 77, "conj_phys": 77, "conj_physical_": 77, "contigu": 77, "contiguous_format": 77, "copy_": 77, "non_block": 77, "resid": 77, "gpu": 77, "asynchron": 77, "host": 77, "copysign": 77, "copysign_": 77, "corrcoef": 77, "cos_": 77, "cosh": 77, "cosh_": 77, "count_nonzero": 77, "cov": 77, "correct": [77, 89, 90, 100, 106, 107], "fweight": 77, "aweight": 77, "crow_indic": 77, "compress": 77, "destin": 77, "pin": 77, "cummax": 77, "cummin": 77, "cumprod": 77, "cumprod_": 77, "cumsum": 77, "cumsum_": 77, "data_ptr": 77, "address": [77, 85, 107], "deg2rad": 77, "deg2rad_": 77, "dense_dim": 77, "dens": 77, "len": 77, "sparse_dim": 77, "hybrid": 77, "dequant": 77, "quantiz": 77, "det": 77, "detach": 77, "never": [77, 87], "affect": 77, "share": [77, 100, 107], "storag": [77, 100], "trigger": 77, "detach_": 77, "diag": 77, "diag_emb": 77, "offset": 77, "dim1": 77, "dim2": 77, "diagflat": 77, "diagonal_scatt": 77, "diff": 77, "digamma": 77, "digamma_": 77, "dim_ord": 77, "physic": 77, "laid": 77, "outermost": 77, "innermost": 77, "channels_last": 77, "dist": 77, "div": 77, "rounding_mod": 77, "div_": 77, "divide_": 77, "doubl": [77, 100], "dsplit": 77, "split_size_or_sect": 77, "element_s": 77, "individu": 77, "eq": 77, "eq_": 77, "erf": 77, "erf_": 77, "erfc": 77, "erfc_": 77, "erfinv": 77, "erfinv_": 77, "exp2": 77, "exp2_": 77, "exp_": 77, "alloc": 77, "As": [77, 90, 107], "especi": 77, "write": 77, "expand_a": 77, "expm1": 77, "expm1_": 77, "exponential_": 77, "lambd": 77, "pdf": 77, "densiti": 77, "theori": [77, 107], "exponenti": 77, "interv": 77, "impli": 77, "fill_": 77, "fill_diagonal_": 77, "fill_valu": 77, "wrap": [77, 95, 101, 102, 107], "main": 77, "tall": 77, "fix_": 77, "flatten": 77, "start_dim": 77, "end_dim": 77, "flip": 77, "fliplr": 77, "flipud": 77, "float_pow": 77, "expon": 77, "float_power_": 77, "floor": 77, "floor_": 77, "floor_divid": 77, "floor_divide_": 77, "fmax": 77, "fmin": 77, "fmod": 77, "divisor": 77, "fmod_": 77, "frac_": 77, "frexp": 77, "mantissa": 77, "gather": 77, "gcd": 77, "gcd_": 77, "ge": 77, "ge_": 77, "geometric_": 77, "trial": 77, "success": 77, "henc": [77, 92], "wherea": 77, "geqrf": 77, "ger": 77, "get_devic": 77, "ordin": 77, "greater_": 77, "greater_equ": 77, "greater_equal_": 77, "gt": [77, 106], "gt_": 77, "half": 77, "hardshrink": 77, "has_nam": 77, "heavisid": 77, "heaviside_": 77, "histc": 77, "histogram": 77, "hsplit": 77, "hypot": 77, "hypot_": 77, "i0": 77, "i0_": 77, "igamma": 77, "igamma_": 77, "igammac": 77, "igammac_": 77, "imaginari": 77, "3100": 77, "3553j": 77, "5445": 77, "7896j": 77, "6492": 77, "0633j": 77, "0638": 77, "8119j": 77, "3553": 77, "7896": 77, "0633": 77, "8119": 77, "index_add": 77, "index_add_": [77, 100], "subtract": 77, "index_copi": 77, "index_copy_": 77, "duplic": 77, "index_fil": 77, "index_fill_": 77, "index_put": 77, "index_put_": 77, "put": [77, 107], "express": [77, 92, 96, 106], "undefin": [77, 100], "index_reduce_": 77, "include_self": 77, "prod": 77, "identit": 77, "11": 77, "12": 77, "44": 77, "72": 77, "14": 77, "22": 77, "36": 77, "index_select": 77, "inner": [77, 92], "int_repr": 77, "uint8_t": 77, "is_coalesc": 77, "is_complex": 77, "is_conj": 77, "is_contigu": 77, "is_cpu": 77, "is_cuda": 77, "is_floating_point": 77, "is_infer": 77, "is_ipu": 77, "is_leaf": 77, "convent": [77, 101], "popul": [77, 101], "retain_grad": 77, "engin": [77, 100], "requires_grad_": [77, 91, 102, 107], "is_meta": 77, "meta": [77, 91], "carri": 77, "is_mp": 77, "mp": 77, "is_neg": 77, "neg": 77, "is_pin": 77, "is_quant": 77, "is_set_to": 77, "exact": [77, 107], "is_shar": 77, "is_sign": 77, "sign": [77, 107], "is_spars": 77, "is_sparse_csr": 77, "is_xla": 77, "xla": 77, "is_xpu": 77, "xpu": 77, "isclos": 77, "isfinit": 77, "isinf": 77, "isnan": 77, "isneginf": 77, "isposinf": 77, "isreal": 77, "istft": 77, "n_fft": 77, "hop_length": 77, "win_length": 77, "window": 77, "center": 77, "onesid": 77, "return_complex": 77, "tolist": 77, "items": 77, "kron": 77, "kthvalu": 77, "lcm": 77, "lcm_": 77, "ldexp": 77, "ldexp_": 77, "le": 77, "le_": 77, "lerp": 77, "lerp_": 77, "less": 77, "lt": 77, "less_": 77, "less_equ": 77, "less_equal_": 77, "lgamma": 77, "lgamma_": 77, "log10": 77, "log10_": 77, "log1p": 77, "log1p_": 77, "log2": 77, "log2_": 77, "log_": 77, "log_normal_": 77, "parameter": 77, "ln": 77, "logaddexp": 77, "logaddexp2": 77, "logcumsumexp": 77, "logdet": 77, "logical_and": 77, "logical_and_": 77, "logical_not_": 77, "logical_or": 77, "logical_or_": 77, "logical_xor": 77, "logical_xor_": 77, "logit_": 77, "logsumexp": 77, "lt_": 77, "lu": 77, "pivot": 77, "get_info": 77, "lu_solv": 77, "lu_data": 77, "lu_pivot": 77, "map_": 77, "masked_fil": 77, "booltensor": 77, "masked_scatt": 77, "masked_scatter_": 77, "continu": 77, "occurr": 77, "mani": [77, 92, 100, 106, 107], "masked_select": 77, "matmul": [77, 89, 90, 104], "matrix_exp": 77, "matrix_pow": 77, "linalg": 77, "module_load": 77, "get_swap_module_params_on_convers": 77, "buffer": 77, "remap": 77, "swap_tensor": 77, "moveaxi": 77, "movedim": 77, "msort": 77, "mul": 77, "mul_": 77, "multinomi": 77, "num_sampl": 77, "multiply_": 77, "mv": 77, "mvlgamma": 77, "mvlgamma_": 77, "idx": [77, 100], "unnam": 77, "charact": [77, 106], "underscor": 77, "variabl": [77, 100], "nan_to_num": 77, "nan": 77, "posinf": 77, "neginf": 77, "nan_to_num_": 77, "nanmean": 77, "nanmedian": 77, "nanquantil": 77, "q": [77, 106], "nansum": 77, "narrow": 77, "narrow_copi": 77, "nbyte": 77, "consum": 77, "ndimens": 77, "ne": 77, "ne_": 77, "neg_": 77, "negative_": 77, "nelement": 77, "new_empti": 77, "pin_memori": 77, "uniniti": 77, "record": [77, 91, 100, 101], "would": [77, 100], "8182e": 77, "5765e": 77, "41": 77, "0545e": 77, "0949e": 77, "4842e": 77, "0000e": 77, "00": 77, "new_empty_strid": 77, "new_ful": 77, "141592": 77, "1416": 77, "new_on": 77, "new_tensor": 77, "want": [77, 105], "numpi": [77, 100], "arrai": 77, "from_numpi": 77, "read": [77, 107], "array_lik": 77, "new_zero": 77, "nextaft": 77, "nextafter_": 77, "nonzero": 77, "nonzero_stat": 77, "count": 77, "truncat": 77, "smaller": [77, 107], "invalid": 77, "input_tensor": 77, "static_s": 77, "rank": 77, "fro": 77, "normal_": 77, "not_equ": 77, "not_equal_": 77, "forc": 77, "ndarrai": 77, "convers": [77, 107], "reflect": [77, 102, 107], "vice": 77, "versa": 77, "resolve_conj": 77, "resolve_neg": 77, "isn": [77, 106], "won": 77, "shorthand": 77, "orgqr": 77, "ormqr": 77, "input3": 77, "outer": 77, "pinvers": 77, "polygamma": 77, "polygamma_": 77, "pow": 77, "pow_": 77, "put_": 77, "q_per_channel_axi": 77, "q_per_channel_scal": 77, "q_per_channel_zero_point": 77, "zero_point": 77, "q_scale": 77, "q_zero_point": 77, "qr": 77, "qscheme": 77, "qtensor": 77, "quantil": 77, "rad2deg": 77, "rad2deg_": 77, "discret": 77, "bound": 77, "53": 77, "ravel": 77, "reciproc": 77, "reciprocal_": 77, "record_stream": 77, "mark": [77, 100], "dealloc": [77, 101], "reus": 77, "until": [77, 107], "queu": 77, "complet": [77, 88, 105], "cach": [77, 85, 94], "awar": 77, "correctli": 77, "life": 77, "cycl": 77, "But": [77, 106], "unexpectedli": 77, "let": 77, "know": [77, 100, 106], "suitabl": 77, "side": 77, "abl": [77, 92, 107], "think": [77, 106], "carefulli": 77, "safeti": 77, "These": [77, 91, 107], "analog": 77, "tradeoff": 77, "gc": 77, "situat": 77, "lifetim": 77, "poll": 77, "race": 77, "creation": 77, "sync": 77, "back": 77, "suffici": [77, 92, 107], "realloc": 77, "done": [77, 92], "counterintuit": 77, "old": [77, 107], "becaus": [77, 107], "wait": 77, "concret": [77, 101], "s0": 77, "s1": 77, "wait_stream": 77, "some_comm_op": 77, "synchron": 77, "del": 77, "decid": 77, "immedi": 77, "wouldn": 77, "finish": 77, "profil": 77, "chrome": 77, "trace": [77, 92], "produc": [77, 107], "export_chrome_trac": 77, "earli": 77, "block": [77, 106, 107], "overlap": 77, "commun": 77, "late": 77, "live": 77, "guidanc": 77, "fsdp": 77, "cudacachingalloc": 77, "refin": 77, "special": [77, 92, 105], "renam": 77, "lift": 77, "coexist": 77, "nice": 77, "greedili": 77, "named_img": 77, "register_hook": 77, "execut": [77, 85, 94, 101], "register_post_accumulate_grad_hook": 77, "unless": [77, 100], "enable_grad": 77, "0100": 77, "0200": 77, "0300": 77, "remaind": 77, "remainder_": 77, "rename_map": 77, "position": 77, "drop": [77, 107], "One": 77, "renamed_img": 77, "height": 77, "width": 77, "rename_": 77, "maxnorm": 77, "renorm_": 77, "repeat": [77, 107], "similar": [77, 100, 107], "tile": 77, "repeat_interleav": 77, "output_s": 77, "fact": [77, 107], "tell": 77, "obtain": 77, "dataload": 77, "preprocess": 77, "sai": 77, "saved_weight": 77, "25": 77, "loaded_weight": 77, "5503": 77, "4926": 77, "1158": 77, "8303": 77, "1007": 77, "9853": 77, "2316": 77, "6606": 77, "compat": [77, 101], "reshape_a": 77, "resize_": 77, "resiz": 77, "fit": 77, "preserv": [77, 106, 107], "level": [77, 107], "reinterpret": 77, "unchang": [77, 80], "custom": [77, 92, 100], "set_": 77, "use_deterministic_algorithm": 77, "fill_uninitialized_memori": 77, "go": [77, 100, 106, 107], "unaffect": 77, "resize_as_": 77, "retains_grad": 77, "roll": 77, "rot90": 77, "decim": 77, "round_": 77, "rsqrt": 77, "rsqrt_": 77, "scatter": 77, "scatter_": 77, "manner": 77, "moreov": 77, "inclus": 77, "uniqu": 77, "pick": 77, "arbitrarili": 77, "propag": 77, "scatter_add_": 77, "scatter_reduce_": 77, "23": 77, "4600": 77, "2300": 77, "scatter_add": 77, "fashion": 77, "scatter_reduc": 77, "select_scatt": 77, "sgn": 77, "sgn_": 77, "share_memory_": 77, "untypedstorag": 77, "short": 77, "int16": 77, "sigmoid_": 77, "sign_": 77, "signbit": 77, "sin": 77, "sin_": 77, "sinc_": 77, "sinh": 77, "sinh_": 77, "slice_scatt": 77, "slogdet": 77, "smm": 77, "sort": [77, 100], "sparse_mask": 77, "filter": 77, "advis": 77, "whose": [77, 102], "nse": 77, "cat": 77, "sparse_coo_tensor": 77, "6550": 77, "2397": 77, "1611": 77, "0779": 77, "2326": 77, "0558": 77, "4711": 77, "9678": 77, "5138": 77, "0411": 77, "9417": 77, "5158": 77, "0793": 77, "0036": 77, "2569": 77, "1055": 77, "sparse_coo": 77, "sparse_resize_": 77, "sparse_resize_and_clear_": 77, "split_siz": 77, "sqrt_": 77, "square_": 77, "squeez": 77, "squeeze_": 77, "sspaddmm": 77, "stft": 77, "pad_mod": 77, "typedstorag": 77, "directli": [77, 100, 105, 107], "untyped_storag": 77, "storage_typ": 77, "jump": 77, "next": [77, 107], "sub": 77, "sub_": 77, "subtract_": 77, "sum_to_s": 77, "svd": 77, "compute_uv": 77, "swapax": 77, "axis0": 77, "axis1": 77, "swapaxes_": 77, "swapdim": 77, "dim0": 77, "swapdims_": 77, "t_": 77, "take_along_dim": 77, "tan": 77, "tan_": 77, "tanh_": 77, "tensor_split": 77, "indices_or_sect": 77, "5044": 77, "0005": 77, "3310": 77, "0584": 77, "cuda0": 77, "to_dens": 77, "masked_grad": 77, "to_mkldnn": 77, "mkldnn": 77, "to_padded_tensor": 77, "to_spars": 77, "sparsedim": 77, "coordin": 77, "blocksiz": 77, "could": [77, 107], "sparse_csc": 77, "sparse_bsr": 77, "sparse_bsc": 77, "bsr": 77, "bsc": 77, "runtimeerror": [77, 100], "except": 77, "evenli": 77, "csc": 77, "minu": 77, "divis": [77, 107], "sparsecsr": 77, "to_sparse_bsc": 77, "row_indic": 77, "ccol_indic": 77, "to_sparse_bsr": 77, "to_sparse_coo": 77, "_nnz": 77, "to_sparse_csc": 77, "nest": [77, 85, 92, 94], "012766935862600803": 77, "5415473580360413": 77, "08909505605697632": 77, "7729271650314331": 77, "topk": 77, "largest": 77, "transpose_": 77, "triangular_solv": 77, "unitriangular": 77, "tril_": 77, "triu": 77, "triu_": 77, "true_divid": 77, "true_divide_": 77, "trunc": 77, "trunc_": 77, "async": 77, "type_a": 77, "unbind": 77, "seq": 77, "unflatten": 77, "unfold": 77, "sizedim": 77, "happen": [77, 100, 106], "uniform_": 77, "return_invers": 77, "return_count": 77, "unique_consecut": 77, "elimin": [77, 107], "consecut": 77, "unsafe_chunk": 77, "unsafe_split": 77, "unsqueez": 77, "unsqueeze_": 77, "vdot": 77, "subspac": 77, "across": 77, "satisfi": 77, "condit": 77, "foral": 77, "unclear": 77, "z": [77, 100, 106], "2nd": 77, "3rd": 77, "proportion": 77, "twice": 77, "met": 77, "overload": 77, "torchscript": [77, 107], "program": [77, 107], "9482": 77, "0310": 77, "4999": 77, "5316": 77, "1520": 77, "7472": 77, "5617": 77, "8649": 77, "4724": 77, "0334": 77, "2976": 77, "8499": 77, "2109": 77, "9913": 77, "9607": 77, "6123": 77, "1064483442": 77, "1124191867": 77, "1069546515": 77, "1089989247": 77, "1105482831": 77, "1061112040": 77, "1057999968": 77, "1084397505": 77, "1071760287": 77, "1123489973": 77, "1097310419": 77, "1084649136": 77, "1101533110": 77, "1073668768": 77, "1082790149": 77, "1088634448": 77, "1000000000": 77, "0047": 77, "0310j": 77, "5316j": 77, "7472j": 77, "8649j": 77, "0334j": 77, "8499j": 77, "9913j": 77, "6123j": 77, "202": 77, "154": 77, "59": 77, "182": 77, "243": 77, "253": 77, "188": 77, "185": 77, "252": 77, "191": 77, "63": 77, "240": 77, "227": 77, "165": 77, "27": 77, "190": 77, "146": 77, "203": 77, "15": 77, "106": 77, "93": 77, "205": 77, "192": 77, "112": 77, "206": 77, "189": 77, "95": 77, "152": 77, "147": 77, "89": 77, "43": 77, "246": 77, "87": 77, "235": [77, 102], "226": 77, "254": 77, "111": 77, "117": 77, "177": [77, 107], "28": 77, "view_a": 77, "vsplit": 77, "xlogi": 77, "xlogy_": 77, "typeguard": 78, "dynamo": 83, "fwd_tensor": 84, "fwd": [84, 107], "bwd": [84, 102, 107], "slightli": [85, 87, 106, 107], "doesn": [85, 92, 94, 95, 97, 100], "unit_scal": [85, 91, 104, 105], "introduc": [85, 106, 107], "compos": 85, "_dynamo": [85, 94, 95], "optimis": [85, 94, 95, 104], "thereaft": [85, 94], "written": 85, "successfulli": 85, "rather": 85, "graphmodul": [85, 94, 101], "simulate_fp8": [85, 91, 92, 104, 105], "node": [86, 87, 88, 91, 96, 101], "suppli": [86, 87, 89, 92, 94, 100, 107], "pruned_graph": [86, 87], "52587890625e": 87, "negligibli": 87, "toler": 87, "signific": [87, 107], "onc": [88, 100], "fwd_format": 89, "bwd_format": 89, "torchdynamo": [89, 90, 92, 94, 95, 97], "scaled_dot_product_attent": [89, 90, 104], "inspect": [89, 90], "fp32": [89, 90, 107], "speedup": [89, 90, 107], "variou": [89, 91], "literatur": 90, "noun": 90, "et": 90, "al": 90, "2022": 90, "micikeviciu": 90, "e4": [90, 107], "e5": [90, 107], "analysi": [91, 92, 102, 103, 104, 107], "prune_non_float_tensor": [91, 104], "prune_same_scale_tensor": [91, 104], "tend": [91, 107], "procedur": 92, "recurs": [92, 94, 95, 97, 101, 102], "build": [92, 107], "fundament": 92, "proce": 92, "five": 92, "stage": 92, "identif": 92, "compar": 92, "unconstrain": 92, "proof": 92, "initialis": [92, 106, 107], "approach": [92, 94, 107], "compil": [92, 94, 95, 104, 105, 107], "own": [92, 94], "system": [92, 94], "easi": [92, 94], "interoper": [92, 94], "definit": 92, "basic": [92, 107], "told": 92, "explicitli": [92, 106], "substitut": 92, "new_gelu": 92, "test": [92, 105, 106, 107], "said": 92, "anticip": 92, "ultim": 92, "alon": 92, "prioriti": 92, "non_recurse_funct": 94, "graph_modul": 94, "backend_1": 94, "backend_2": 94, "_modul": 94, "torch_nn_modules_to_user_modul": [95, 104], "patch": 95, "target_fn": 96, "keep_type_expr": 96, "accompani": [96, 106], "retain": [96, 106], "mod": 97, "trivial_subclass": 97, "develop": [98, 107], "dataclass": 99, "scalepair": [100, 104], "ctx": 100, "functionctx": 100, "vjp": 100, "needs_input_grad": 100, "jvp": 100, "grad_input": 100, "got": 100, "mark_dirti": 100, "setup_context": 100, "matter": 100, "torch_doctest_autograd": 100, "staticmethod": 100, "x_npy": 100, "once_differenti": 100, "grad_output": 100, "lead": [100, 107], "wrong": 100, "mark_non_differenti": 100, "save_for_backward": 100, "g1": 100, "g2": 100, "saved_tensor": 100, "zeros_lik": 100, "oppos": 100, "leak": 100, "saved_tensors_hook": 100, "intermediari": 100, "neither": 100, "nor": 100, "recomput": 100, "tutori": [100, 107], "weren": 100, "grad_out": 100, "gx": 100, "gy": 100, "gz": 100, "save_for_forward": 100, "x_t": 100, "y_t": 100, "fwad": 100, "dual_level": 100, "a_dual": 100, "make_du": 100, "set_materialize_grad": 100, "materi": 100, "simplefunc": 100, "No": 100, "induc": 100, "insid": 100, "vmap": 100, "info": 100, "in_dim": 100, "underneath": 100, "generate_vmap_rul": 100, "choos": 100, "out_dim": 100, "instrument": 101, "boxed_run": 101, "args_list": 101, "interpret": 101, "promptli": 101, "call_funct": 101, "invoc": 101, "call_method": 101, "opoverload": 101, "call_modul": 101, "fetch_args_kwargs_from_env": 101, "fetch": 101, "environ": 101, "fetch_attr": 101, "hierarchi": 101, "qualifi": 101, "get_attr": 101, "Will": 101, "map_nodes_to_valu": 101, "belong": 101, "report": 101, "realli": 101, "referenc": 101, "placehold": 101, "tracer": 101, "target_to_funct": 101, "initial_env": 101, "enable_io_process": 101, "partial": 101, "process_input": 101, "process_output": 101, "run_nod": 101, "recurse_modul": 102, "syntax_highlight": 102, "autowrap_modul": 102, "einop": 102, "home": 102, "runner": 102, "lib": 102, "site": 102, "packag": 102, "py": 102, "autowrap_funct": 102, "dummi": 102, "union": 102, "fed": [102, 107], "plain": 102, "toggl": 102, "behavour": 102, "moduletyp": 102, "fc1": 102, "fc2": 102, "236": 102, "fc1_weight": 102, "018": [102, 107], "54": 102, "fc1_bia": 102, "0182": 102, "51": 102, "_c": [102, 107], "_nn": [102, 107], "578": [102, 107], "204": [102, 107], "337": 102, "288": 102, "fc2_weight": 102, "00902": [102, 107], "13": 102, "fc2_bia": 102, "00904": 102, "31": 102, "linear_1": [102, 107], "welcom": 104, "design": [104, 106, 107], "facilit": 104, "outlin": [104, 107], "icml": [104, 107], "broad": 104, "fork": [104, 107], "repo": [104, 107], "instruct": [104, 107], "consider": 104, "licens": 104, "blog": 104, "almost": 104, "depthmodulelist": 104, "depthsequenti": 104, "linearreadout": 104, "mhsa": 104, "transformerdecod": 104, "transformerlay": 104, "graph_to_datafram": 104, "apply_constraint": 104, "format_to_tupl": 104, "tuple_to_format": 104, "cross_entropi": [104, 107], "linear_readout": 104, "mse_loss": 104, "residual_appli": 104, "rms_norm": 104, "silu_glu": 104, "lr_scale_for_depth": 104, "lr_scale_func_adam": 104, "scale_bwd": 104, "scale_fwd": 104, "prune_selected_nod": 104, "simulate_format": 104, "apply_transform": 104, "patch_to_expand_modul": 104, "replace_node_with_funct": 104, "analyse_modul": [104, 107], "scaletrack": 104, "scaletrackinginterpret": 104, "logarithmic_interpol": 104, "scale_elementwis": 104, "despit": 105, "best": [105, 107], "effort": 105, "free": 105, "assist": 105, "anyon": 105, "issu": [105, 107], "coverag": 105, "ve": [105, 106], "focuss": 105, "difficulti": 105, "although": [105, 107], "suspect": 105, "haven": 105, "exhaust": 105, "encourag": 105, "touch": 105, "tl": 106, "dr": 106, "good": [106, 107], "thing": 106, "roughli": [106, 107], "behaviour": [106, 107], "satur": 106, "stabl": 106, "prime": 106, "color": 106, "green": 106, "insuffici": 106, "red": 106, "ll": 106, "explain": 106, "dynam": [106, 107], "wors": 106, "condens": 106, "summari": 106, "sim": 106, "infti": 106, "flat": 106, "uncorrel": 106, "spike": 106, "assumpt": [106, 107], "companion": 106, "find": [106, 107], "propos": 106, "autoregress": 106, "languag": 106, "shakespear": 106, "saw": 106, "sweep": 106, "unfortun": 106, "tini": 106, "shakespar": 106, "bert": 106, "intrigu": 106, "presum": 106, "accid": 106, "turn": 106, "solut": 106, "bad": 106, "reproduc": 106, "inde": 106, "care": [106, 107], "principl": 106, "underpin": 106, "far": [106, 107], "question": 106, "interest": 106, "With": 106, "thank": 106, "charli": 106, "blake": 106, "feedback": 106, "douglaso": 106, "ai": 106, "cover": 107, "brief": 107, "discuss": 107, "paradigm": 107, "aim": 107, "involv": 107, "scratch": 107, "advantag": 107, "great": 107, "headroom": 107, "grow": 107, "shrink": 107, "underflow": 107, "drift": 107, "fp16": 107, "decreas": 107, "treatment": 107, "motiv": 107, "sens": 107, "bf16": 107, "veri": 107, "larg": 107, "3e": 107, "38": 107, "45": 107, "60": 107, "000": 107, "6e": 107, "speed": 107, "scope": 107, "tricki": 107, "easier": 107, "breakdown": 107, "alongsid": 107, "unscaledmlp": 107, "linear_2": 107, "annotated_cod": 107, "linear_1_weight": 107, "83": 107, "linear_1_bia": 107, "84": 107, "322": 107, "289": 107, "linear_2_weight": 107, "48": 107, "linear_2_bia": 107, "00894": 107, "198": 107, "firstli": 107, "decompos": 107, "secondli": 107, "fwd_scale": 107, "bwd_scale": 107, "enough": 107, "unscal": 107, "scaledmlp": 107, "716": 107, "729": 107, "707": 107, "706": 107, "693": 107, "03": 107, "979": 107, "art": 107, "aris": 107, "clearli": 107, "explod": 107, "degrad": 107, "steadili": 107, "meet": 107, "concern": 107, "merit": 107, "investig": 107, "attain": 107, "substanti": 107, "push": 107, "themselv": 107, "solv": 107, "outsid": 107, "separ": 107, "residuallay": 107, "contrast": 107, "50": 107, "down": 107, "emploi": 107, "trick": 107, "comprehens": 107, "scenario": 107, "arriv": 107, "fan_in": 107, "fan_out": 107, "grad_weight_scal": 107, "grad_bias_scal": 107, "ideal": 107, "compromis": 107, "eager": 107, "trip": 107, "fortun": 107, "hi": 107, "overhead": 107, "fusion": 107, "answer": 107, "jit": 107, "script": 107, "rectifi": 107, "flexibl": 107, "unit_scaled_funct": 107, "unitscaledmodul": 107, "incur": 107, "naiv": 107, "benchmark": 107, "thorough": 107, "strongli": 107, "latest": 107, "recent": 107, "upgrad": 107, "preview": 107, "nightli": 107}, "objects": {"": [[3, 0, 0, "-", "unit_scaling"]], "unit_scaling": [[4, 1, 1, "", "CrossEntropyLoss"], [5, 1, 1, "", "DepthModuleList"], [6, 1, 1, "", "DepthSequential"], [7, 1, 1, "", "Dropout"], [8, 1, 1, "", "Embedding"], [9, 1, 1, "", "GELU"], [10, 1, 1, "", "LayerNorm"], [11, 1, 1, "", "Linear"], [12, 1, 1, "", "LinearReadout"], [13, 1, 1, "", "MHSA"], [14, 1, 1, "", "MLP"], [15, 4, 1, "", "Parameter"], [16, 1, 1, "", "RMSNorm"], [17, 1, 1, "", "SiLU"], [18, 1, 1, "", "Softmax"], [19, 1, 1, "", "TransformerDecoder"], [20, 1, 1, "", "TransformerLayer"], [21, 0, 0, "-", "analysis"], [26, 0, 0, "-", "constraints"], [35, 0, 0, "-", "core"], [41, 0, 0, "-", "formats"], [45, 0, 0, "-", "functional"], [64, 0, 0, "-", "optim"], [72, 0, 0, "-", "parameter"], [79, 0, 0, "-", "scale"], [82, 4, 1, "", "transformer_residual_scaling_rule"], [83, 0, 0, "-", "transforms"], [98, 0, 0, "-", "utils"], [103, 4, 1, "", "visualiser"]], "unit_scaling.DepthModuleList": [[5, 2, 1, "", "append"], [5, 2, 1, "", "extend"], [5, 2, 1, "", "insert"]], "unit_scaling.DepthSequential": [[6, 2, 1, "", "append"]], "unit_scaling.Embedding": [[8, 2, 1, "", "from_pretrained"], [8, 3, 1, "", "weight"]], "unit_scaling.LayerNorm": [[10, 3, 1, "", "bias"], [10, 3, 1, "", "weight"]], "unit_scaling.Linear": [[11, 3, 1, "", "bias"], [11, 3, 1, "", "weight"]], "unit_scaling.LinearReadout": [[12, 3, 1, "", "bias"], [12, 3, 1, "", "weight"]], "unit_scaling.RMSNorm": [[16, 3, 1, "", "weight"]], "unit_scaling.TransformerDecoder": [[19, 2, 1, "", "append"]], "unit_scaling.analysis": [[22, 4, 1, "", "example_batch"], [23, 4, 1, "", "graph_to_dataframe"], [24, 4, 1, "", "plot"], [25, 4, 1, "", "visualiser"]], "unit_scaling.constraints": [[27, 4, 1, "", "amean"], [28, 4, 1, "", "apply_constraint"], [29, 4, 1, "", "gmean"], [30, 4, 1, "", "hmean"], [31, 4, 1, "", "to_grad_input_scale"], [32, 4, 1, "", "to_left_grad_scale"], [33, 4, 1, "", "to_output_scale"], [34, 4, 1, "", "to_right_grad_scale"]], "unit_scaling.core": [[36, 0, 0, "-", "functional"]], "unit_scaling.core.functional": [[37, 4, 1, "", "logarithmic_interpolation"], [38, 4, 1, "", "rms"], [39, 4, 1, "", "scale_elementwise"], [40, 4, 1, "", "transformer_residual_scaling_rule"]], "unit_scaling.formats": [[42, 1, 1, "", "FPFormat"], [43, 4, 1, "", "format_to_tuple"], [44, 4, 1, "", "tuple_to_format"]], "unit_scaling.formats.FPFormat": [[42, 5, 1, "", "bits"], [42, 5, 1, "", "max_absolute_value"], [42, 5, 1, "", "min_absolute_normal"], [42, 5, 1, "", "min_absolute_subnormal"], [42, 2, 1, "", "quantise"], [42, 2, 1, "", "quantise_bwd"], [42, 2, 1, "", "quantise_fwd"]], "unit_scaling.functional": [[46, 4, 1, "", "add"], [47, 4, 1, "", "cross_entropy"], [48, 4, 1, "", "dropout"], [49, 4, 1, "", "embedding"], [50, 4, 1, "", "gelu"], [51, 4, 1, "", "layer_norm"], [52, 4, 1, "", "linear"], [53, 4, 1, "", "linear_readout"], [54, 4, 1, "", "matmul"], [55, 4, 1, "", "mse_loss"], [56, 4, 1, "", "residual_add"], [57, 4, 1, "", "residual_apply"], [58, 4, 1, "", "residual_split"], [59, 4, 1, "", "rms_norm"], [60, 4, 1, "", "scaled_dot_product_attention"], [61, 4, 1, "", "silu"], [62, 4, 1, "", "silu_glu"], [63, 4, 1, "", "softmax"]], "unit_scaling.optim": [[65, 1, 1, "", "Adam"], [66, 1, 1, "", "AdamW"], [67, 1, 1, "", "SGD"], [68, 4, 1, "", "lr_scale_for_depth"], [69, 4, 1, "", "lr_scale_func_adam"], [70, 4, 1, "", "lr_scale_func_sgd"], [71, 4, 1, "", "scaled_parameters"]], "unit_scaling.optim.Adam": [[65, 2, 1, "", "add_param_group"], [65, 2, 1, "", "load_state_dict"], [65, 2, 1, "", "register_load_state_dict_post_hook"], [65, 2, 1, "", "register_load_state_dict_pre_hook"], [65, 2, 1, "", "register_state_dict_post_hook"], [65, 2, 1, "", "register_state_dict_pre_hook"], [65, 2, 1, "", "register_step_post_hook"], [65, 2, 1, "", "register_step_pre_hook"], [65, 2, 1, "", "state_dict"], [65, 2, 1, "", "step"], [65, 2, 1, "", "zero_grad"]], "unit_scaling.optim.AdamW": [[66, 2, 1, "", "add_param_group"], [66, 2, 1, "", "load_state_dict"], [66, 2, 1, "", "register_load_state_dict_post_hook"], [66, 2, 1, "", "register_load_state_dict_pre_hook"], [66, 2, 1, "", "register_state_dict_post_hook"], [66, 2, 1, "", "register_state_dict_pre_hook"], [66, 2, 1, "", "register_step_post_hook"], [66, 2, 1, "", "register_step_pre_hook"], [66, 2, 1, "", "state_dict"], [66, 2, 1, "", "step"], [66, 2, 1, "", "zero_grad"]], "unit_scaling.optim.SGD": [[67, 2, 1, "", "add_param_group"], [67, 2, 1, "", "load_state_dict"], [67, 2, 1, "", "register_load_state_dict_post_hook"], [67, 2, 1, "", "register_load_state_dict_pre_hook"], [67, 2, 1, "", "register_state_dict_post_hook"], [67, 2, 1, "", "register_state_dict_pre_hook"], [67, 2, 1, "", "register_step_post_hook"], [67, 2, 1, "", "register_step_pre_hook"], [67, 2, 1, "", "state_dict"], [67, 2, 1, "", "step"], [67, 2, 1, "", "zero_grad"]], "unit_scaling.parameter": [[73, 1, 1, "", "OrderedDict"], [74, 4, 1, "", "Parameter"], [75, 1, 1, "", "ParameterData"], [76, 1, 1, "", "Protocol"], [77, 1, 1, "", "Tensor"], [78, 4, 1, "", "has_parameter_data"]], "unit_scaling.parameter.OrderedDict": [[73, 2, 1, "", "clear"], [73, 2, 1, "", "copy"], [73, 2, 1, "", "fromkeys"], [73, 2, 1, "", "get"], [73, 2, 1, "", "items"], [73, 2, 1, "", "keys"], [73, 2, 1, "", "move_to_end"], [73, 2, 1, "", "pop"], [73, 2, 1, "", "popitem"], [73, 2, 1, "", "setdefault"], [73, 2, 1, "", "update"], [73, 2, 1, "", "values"]], "unit_scaling.parameter.Tensor": [[77, 3, 1, "", "H"], [77, 3, 1, "", "T"], [77, 2, 1, "", "abs"], [77, 2, 1, "", "abs_"], [77, 2, 1, "", "absolute"], [77, 2, 1, "", "absolute_"], [77, 2, 1, "", "acos"], [77, 2, 1, "", "acos_"], [77, 2, 1, "", "acosh"], [77, 2, 1, "", "acosh_"], [77, 2, 1, "", "add"], [77, 2, 1, "", "add_"], [77, 2, 1, "", "addbmm"], [77, 2, 1, "", "addbmm_"], [77, 2, 1, "", "addcdiv"], [77, 2, 1, "", "addcdiv_"], [77, 2, 1, "", "addcmul"], [77, 2, 1, "", "addcmul_"], [77, 2, 1, "", "addmm"], [77, 2, 1, "", "addmm_"], [77, 2, 1, "", "addmv"], [77, 2, 1, "", "addmv_"], [77, 2, 1, "", "addr"], [77, 2, 1, "", "addr_"], [77, 2, 1, "", "adjoint"], [77, 2, 1, "", "align_as"], [77, 2, 1, "", "align_to"], [77, 2, 1, "", "all"], [77, 2, 1, "", "allclose"], [77, 2, 1, "", "amax"], [77, 2, 1, "", "amin"], [77, 2, 1, "", "aminmax"], [77, 2, 1, "", "angle"], [77, 2, 1, "", "any"], [77, 2, 1, "", "apply_"], [77, 2, 1, "", "arccos"], [77, 2, 1, "", "arccos_"], [77, 2, 1, "", "arccosh"], [77, 2, 1, "", "arccosh_"], [77, 2, 1, "", "arcsin"], [77, 2, 1, "", "arcsin_"], [77, 2, 1, "", "arcsinh"], [77, 2, 1, "", "arcsinh_"], [77, 2, 1, "", "arctan"], [77, 2, 1, "", "arctan2"], [77, 2, 1, "", "arctan2_"], [77, 2, 1, "", "arctan_"], [77, 2, 1, "", "arctanh"], [77, 2, 1, "", "arctanh_"], [77, 2, 1, "", "argmax"], [77, 2, 1, "", "argmin"], [77, 2, 1, "", "argsort"], [77, 2, 1, "", "argwhere"], [77, 2, 1, "", "as_strided"], [77, 2, 1, "", "as_strided_"], [77, 2, 1, "", "as_strided_scatter"], [77, 2, 1, "", "as_subclass"], [77, 2, 1, "", "asin"], [77, 2, 1, "", "asin_"], [77, 2, 1, "", "asinh"], [77, 2, 1, "", "asinh_"], [77, 2, 1, "", "atan"], [77, 2, 1, "", "atan2"], [77, 2, 1, "", "atan2_"], [77, 2, 1, "", "atan_"], [77, 2, 1, "", "atanh"], [77, 2, 1, "", "atanh_"], [77, 2, 1, "", "backward"], [77, 2, 1, "", "baddbmm"], [77, 2, 1, "", "baddbmm_"], [77, 2, 1, "", "bernoulli"], [77, 2, 1, "", "bernoulli_"], [77, 2, 1, "", "bfloat16"], [77, 2, 1, "", "bincount"], [77, 2, 1, "", "bitwise_and"], [77, 2, 1, "", "bitwise_and_"], [77, 2, 1, "", "bitwise_left_shift"], [77, 2, 1, "", "bitwise_left_shift_"], [77, 2, 1, "", "bitwise_not"], [77, 2, 1, "", "bitwise_not_"], [77, 2, 1, "", "bitwise_or"], [77, 2, 1, "", "bitwise_or_"], [77, 2, 1, "", "bitwise_right_shift"], [77, 2, 1, "", "bitwise_right_shift_"], [77, 2, 1, "", "bitwise_xor"], [77, 2, 1, "", "bitwise_xor_"], [77, 2, 1, "", "bmm"], [77, 2, 1, "", "bool"], [77, 2, 1, "", "broadcast_to"], [77, 2, 1, "", "byte"], [77, 2, 1, "", "cauchy_"], [77, 2, 1, "", "cdouble"], [77, 2, 1, "", "ceil"], [77, 2, 1, "", "ceil_"], [77, 2, 1, "", "cfloat"], [77, 2, 1, "", "chalf"], [77, 2, 1, "", "char"], [77, 2, 1, "", "cholesky"], [77, 2, 1, "", "cholesky_inverse"], [77, 2, 1, "", "cholesky_solve"], [77, 2, 1, "", "chunk"], [77, 2, 1, "", "clamp"], [77, 2, 1, "", "clamp_"], [77, 2, 1, "", "clip"], [77, 2, 1, "", "clip_"], [77, 2, 1, "", "clone"], [77, 2, 1, "", "coalesce"], [77, 2, 1, "", "col_indices"], [77, 2, 1, "", "conj"], [77, 2, 1, "", "conj_physical"], [77, 2, 1, "", "conj_physical_"], [77, 2, 1, "", "contiguous"], [77, 2, 1, "", "copy_"], [77, 2, 1, "", "copysign"], [77, 2, 1, "", "copysign_"], [77, 2, 1, "", "corrcoef"], [77, 2, 1, "", "cos"], [77, 2, 1, "", "cos_"], [77, 2, 1, "", "cosh"], [77, 2, 1, "", "cosh_"], [77, 2, 1, "", "count_nonzero"], [77, 2, 1, "", "cov"], [77, 2, 1, "", "cpu"], [77, 2, 1, "", "cross"], [77, 2, 1, "", "crow_indices"], [77, 2, 1, "", "cuda"], [77, 2, 1, "", "cummax"], [77, 2, 1, "", "cummin"], [77, 2, 1, "", "cumprod"], [77, 2, 1, "", "cumprod_"], [77, 2, 1, "", "cumsum"], [77, 2, 1, "", "cumsum_"], [77, 2, 1, "", "data_ptr"], [77, 2, 1, "", "deg2rad"], [77, 2, 1, "", "deg2rad_"], [77, 2, 1, "", "dense_dim"], [77, 2, 1, "", "dequantize"], [77, 2, 1, "", "det"], [77, 2, 1, "", "detach"], [77, 2, 1, "", "detach_"], [77, 3, 1, "", "device"], [77, 2, 1, "", "diag"], [77, 2, 1, "", "diag_embed"], [77, 2, 1, "", "diagflat"], [77, 2, 1, "", "diagonal"], [77, 2, 1, "", "diagonal_scatter"], [77, 2, 1, "", "diff"], [77, 2, 1, "", "digamma"], [77, 2, 1, "", "digamma_"], [77, 2, 1, "", "dim"], [77, 2, 1, "", "dim_order"], [77, 2, 1, "", "dist"], [77, 2, 1, "", "div"], [77, 2, 1, "", "div_"], [77, 2, 1, "", "divide"], [77, 2, 1, "", "divide_"], [77, 2, 1, "", "dot"], [77, 2, 1, "", "double"], [77, 2, 1, "", "dsplit"], [77, 2, 1, "", "element_size"], [77, 2, 1, "", "eq"], [77, 2, 1, "", "eq_"], [77, 2, 1, "", "equal"], [77, 2, 1, "", "erf"], [77, 2, 1, "", "erf_"], [77, 2, 1, "", "erfc"], [77, 2, 1, "", "erfc_"], [77, 2, 1, "", "erfinv"], [77, 2, 1, "", "erfinv_"], [77, 2, 1, "", "exp"], [77, 2, 1, "", "exp2"], [77, 2, 1, "", "exp2_"], [77, 2, 1, "", "exp_"], [77, 2, 1, "", "expand"], [77, 2, 1, "", "expand_as"], [77, 2, 1, "", "expm1"], [77, 2, 1, "", "expm1_"], [77, 2, 1, "", "exponential_"], [77, 2, 1, "", "fill_"], [77, 2, 1, "", "fill_diagonal_"], [77, 2, 1, "", "fix"], [77, 2, 1, "", "fix_"], [77, 2, 1, "", "flatten"], [77, 2, 1, "", "flip"], [77, 2, 1, "", "fliplr"], [77, 2, 1, "", "flipud"], [77, 2, 1, "", "float"], [77, 2, 1, "", "float_power"], [77, 2, 1, "", "float_power_"], [77, 2, 1, "", "floor"], [77, 2, 1, "", "floor_"], [77, 2, 1, "", "floor_divide"], [77, 2, 1, "", "floor_divide_"], [77, 2, 1, "", "fmax"], [77, 2, 1, "", "fmin"], [77, 2, 1, "", "fmod"], [77, 2, 1, "", "fmod_"], [77, 2, 1, "", "frac"], [77, 2, 1, "", "frac_"], [77, 2, 1, "", "frexp"], [77, 2, 1, "", "gather"], [77, 2, 1, "", "gcd"], [77, 2, 1, "", "gcd_"], [77, 2, 1, "", "ge"], [77, 2, 1, "", "ge_"], [77, 2, 1, "", "geometric_"], [77, 2, 1, "", "geqrf"], [77, 2, 1, "", "ger"], [77, 2, 1, "", "get_device"], [77, 3, 1, "", "grad"], [77, 2, 1, "", "greater"], [77, 2, 1, "", "greater_"], [77, 2, 1, "", "greater_equal"], [77, 2, 1, "", "greater_equal_"], [77, 2, 1, "", "gt"], [77, 2, 1, "", "gt_"], [77, 2, 1, "", "half"], [77, 2, 1, "", "hardshrink"], [77, 2, 1, "", "has_names"], [77, 2, 1, "", "heaviside"], [77, 2, 1, "", "heaviside_"], [77, 2, 1, "", "histc"], [77, 2, 1, "", "histogram"], [77, 2, 1, "", "hsplit"], [77, 2, 1, "", "hypot"], [77, 2, 1, "", "hypot_"], [77, 2, 1, "", "i0"], [77, 2, 1, "", "i0_"], [77, 2, 1, "", "igamma"], [77, 2, 1, "", "igamma_"], [77, 2, 1, "", "igammac"], [77, 2, 1, "", "igammac_"], [77, 3, 1, "", "imag"], [77, 2, 1, "", "index_add"], [77, 2, 1, "", "index_add_"], [77, 2, 1, "", "index_copy"], [77, 2, 1, "", "index_copy_"], [77, 2, 1, "", "index_fill"], [77, 2, 1, "", "index_fill_"], [77, 2, 1, "", "index_put"], [77, 2, 1, "", "index_put_"], [77, 2, 1, "", "index_reduce_"], [77, 2, 1, "", "index_select"], [77, 2, 1, "", "indices"], [77, 2, 1, "", "inner"], [77, 2, 1, "", "int"], [77, 2, 1, "", "int_repr"], [77, 2, 1, "", "inverse"], [77, 2, 1, "", "ipu"], [77, 2, 1, "", "is_coalesced"], [77, 2, 1, "", "is_complex"], [77, 2, 1, "", "is_conj"], [77, 2, 1, "", "is_contiguous"], [77, 3, 1, "", "is_cpu"], [77, 3, 1, "", "is_cuda"], [77, 2, 1, "", "is_floating_point"], [77, 2, 1, "", "is_inference"], [77, 3, 1, "", "is_ipu"], [77, 3, 1, "", "is_leaf"], [77, 3, 1, "", "is_meta"], [77, 3, 1, "", "is_mps"], [77, 2, 1, "", "is_neg"], [77, 2, 1, "", "is_pinned"], [77, 3, 1, "", "is_quantized"], [77, 2, 1, "", "is_set_to"], [77, 2, 1, "", "is_shared"], [77, 2, 1, "", "is_signed"], [77, 3, 1, "", "is_sparse"], [77, 3, 1, "", "is_sparse_csr"], [77, 3, 1, "", "is_xla"], [77, 3, 1, "", "is_xpu"], [77, 2, 1, "", "isclose"], [77, 2, 1, "", "isfinite"], [77, 2, 1, "", "isinf"], [77, 2, 1, "", "isnan"], [77, 2, 1, "", "isneginf"], [77, 2, 1, "", "isposinf"], [77, 2, 1, "", "isreal"], [77, 2, 1, "", "istft"], [77, 2, 1, "", "item"], [77, 3, 1, "", "itemsize"], [77, 2, 1, "", "kron"], [77, 2, 1, "", "kthvalue"], [77, 2, 1, "", "lcm"], [77, 2, 1, "", "lcm_"], [77, 2, 1, "", "ldexp"], [77, 2, 1, "", "ldexp_"], [77, 2, 1, "", "le"], [77, 2, 1, "", "le_"], [77, 2, 1, "", "lerp"], [77, 2, 1, "", "lerp_"], [77, 2, 1, "", "less"], [77, 2, 1, "", "less_"], [77, 2, 1, "", "less_equal"], [77, 2, 1, "", "less_equal_"], [77, 2, 1, "", "lgamma"], [77, 2, 1, "", "lgamma_"], [77, 2, 1, "", "log"], [77, 2, 1, "", "log10"], [77, 2, 1, "", "log10_"], [77, 2, 1, "", "log1p"], [77, 2, 1, "", "log1p_"], [77, 2, 1, "", "log2"], [77, 2, 1, "", "log2_"], [77, 2, 1, "", "log_"], [77, 2, 1, "", "log_normal_"], [77, 2, 1, "", "logaddexp"], [77, 2, 1, "", "logaddexp2"], [77, 2, 1, "", "logcumsumexp"], [77, 2, 1, "", "logdet"], [77, 2, 1, "", "logical_and"], [77, 2, 1, "", "logical_and_"], [77, 2, 1, "", "logical_not"], [77, 2, 1, "", "logical_not_"], [77, 2, 1, "", "logical_or"], [77, 2, 1, "", "logical_or_"], [77, 2, 1, "", "logical_xor"], [77, 2, 1, "", "logical_xor_"], [77, 2, 1, "", "logit"], [77, 2, 1, "", "logit_"], [77, 2, 1, "", "logsumexp"], [77, 2, 1, "", "long"], [77, 2, 1, "", "lt"], [77, 2, 1, "", "lt_"], [77, 2, 1, "", "lu"], [77, 2, 1, "", "lu_solve"], [77, 3, 1, "", "mH"], [77, 3, 1, "", "mT"], [77, 2, 1, "", "map_"], [77, 2, 1, "", "masked_fill"], [77, 2, 1, "", "masked_fill_"], [77, 2, 1, "", "masked_scatter"], [77, 2, 1, "", "masked_scatter_"], [77, 2, 1, "", "masked_select"], [77, 2, 1, "", "matmul"], [77, 2, 1, "", "matrix_exp"], [77, 2, 1, "", "matrix_power"], [77, 2, 1, "", "max"], [77, 2, 1, "", "maximum"], [77, 2, 1, "", "mean"], [77, 2, 1, "", "median"], [77, 2, 1, "", "min"], [77, 2, 1, "", "minimum"], [77, 2, 1, "", "mm"], [77, 2, 1, "", "mode"], [77, 2, 1, "", "module_load"], [77, 2, 1, "", "moveaxis"], [77, 2, 1, "", "movedim"], [77, 2, 1, "", "msort"], [77, 2, 1, "", "mul"], [77, 2, 1, "", "mul_"], [77, 2, 1, "", "multinomial"], [77, 2, 1, "", "multiply"], [77, 2, 1, "", "multiply_"], [77, 2, 1, "", "mv"], [77, 2, 1, "", "mvlgamma"], [77, 2, 1, "", "mvlgamma_"], [77, 3, 1, "", "names"], [77, 2, 1, "", "nan_to_num"], [77, 2, 1, "", "nan_to_num_"], [77, 2, 1, "", "nanmean"], [77, 2, 1, "", "nanmedian"], [77, 2, 1, "", "nanquantile"], [77, 2, 1, "", "nansum"], [77, 2, 1, "", "narrow"], [77, 2, 1, "", "narrow_copy"], [77, 3, 1, "", "nbytes"], [77, 3, 1, "", "ndim"], [77, 2, 1, "", "ndimension"], [77, 2, 1, "", "ne"], [77, 2, 1, "", "ne_"], [77, 2, 1, "", "neg"], [77, 2, 1, "", "neg_"], [77, 2, 1, "", "negative"], [77, 2, 1, "", "negative_"], [77, 2, 1, "", "nelement"], [77, 2, 1, "", "new_empty"], [77, 2, 1, "", "new_empty_strided"], [77, 2, 1, "", "new_full"], [77, 2, 1, "", "new_ones"], [77, 2, 1, "", "new_tensor"], [77, 2, 1, "", "new_zeros"], [77, 2, 1, "", "nextafter"], [77, 2, 1, "", "nextafter_"], [77, 2, 1, "", "nonzero"], [77, 2, 1, "", "nonzero_static"], [77, 2, 1, "", "norm"], [77, 2, 1, "", "normal_"], [77, 2, 1, "", "not_equal"], [77, 2, 1, "", "not_equal_"], [77, 2, 1, "", "numel"], [77, 2, 1, "", "numpy"], [77, 2, 1, "", "orgqr"], [77, 2, 1, "", "ormqr"], [77, 2, 1, "", "outer"], [77, 2, 1, "", "permute"], [77, 2, 1, "", "pin_memory"], [77, 2, 1, "", "pinverse"], [77, 2, 1, "", "polygamma"], [77, 2, 1, "", "polygamma_"], [77, 2, 1, "", "positive"], [77, 2, 1, "", "pow"], [77, 2, 1, "", "pow_"], [77, 2, 1, "", "prod"], [77, 2, 1, "", "put"], [77, 2, 1, "", "put_"], [77, 2, 1, "", "q_per_channel_axis"], [77, 2, 1, "", "q_per_channel_scales"], [77, 2, 1, "", "q_per_channel_zero_points"], [77, 2, 1, "", "q_scale"], [77, 2, 1, "", "q_zero_point"], [77, 2, 1, "", "qr"], [77, 2, 1, "", "qscheme"], [77, 2, 1, "", "quantile"], [77, 2, 1, "", "rad2deg"], [77, 2, 1, "", "rad2deg_"], [77, 2, 1, "", "random_"], [77, 2, 1, "", "ravel"], [77, 3, 1, "", "real"], [77, 2, 1, "", "reciprocal"], [77, 2, 1, "", "reciprocal_"], [77, 2, 1, "", "record_stream"], [77, 2, 1, "", "refine_names"], [77, 2, 1, "", "register_hook"], [77, 2, 1, "", "register_post_accumulate_grad_hook"], [77, 2, 1, "", "remainder"], [77, 2, 1, "", "remainder_"], [77, 2, 1, "", "rename"], [77, 2, 1, "", "rename_"], [77, 2, 1, "", "renorm"], [77, 2, 1, "", "renorm_"], [77, 2, 1, "", "repeat"], [77, 2, 1, "", "repeat_interleave"], [77, 3, 1, "", "requires_grad"], [77, 2, 1, "", "requires_grad_"], [77, 2, 1, "", "reshape"], [77, 2, 1, "", "reshape_as"], [77, 2, 1, "", "resize_"], [77, 2, 1, "", "resize_as_"], [77, 2, 1, "", "resolve_conj"], [77, 2, 1, "", "resolve_neg"], [77, 2, 1, "", "retain_grad"], [77, 3, 1, "", "retains_grad"], [77, 2, 1, "", "roll"], [77, 2, 1, "", "rot90"], [77, 2, 1, "", "round"], [77, 2, 1, "", "round_"], [77, 2, 1, "", "rsqrt"], [77, 2, 1, "", "rsqrt_"], [77, 2, 1, "", "scatter"], [77, 2, 1, "", "scatter_"], [77, 2, 1, "", "scatter_add"], [77, 2, 1, "", "scatter_add_"], [77, 2, 1, "", "scatter_reduce"], [77, 2, 1, "", "scatter_reduce_"], [77, 2, 1, "", "select"], [77, 2, 1, "", "select_scatter"], [77, 2, 1, "", "set_"], [77, 2, 1, "", "sgn"], [77, 2, 1, "", "sgn_"], [77, 3, 1, "", "shape"], [77, 2, 1, "", "share_memory_"], [77, 2, 1, "", "short"], [77, 2, 1, "", "sigmoid"], [77, 2, 1, "", "sigmoid_"], [77, 2, 1, "", "sign"], [77, 2, 1, "", "sign_"], [77, 2, 1, "", "signbit"], [77, 2, 1, "", "sin"], [77, 2, 1, "", "sin_"], [77, 2, 1, "", "sinc"], [77, 2, 1, "", "sinc_"], [77, 2, 1, "", "sinh"], [77, 2, 1, "", "sinh_"], [77, 2, 1, "", "size"], [77, 2, 1, "", "slice_scatter"], [77, 2, 1, "", "slogdet"], [77, 2, 1, "", "smm"], [77, 2, 1, "", "softmax"], [77, 2, 1, "", "sort"], [77, 2, 1, "", "sparse_dim"], [77, 2, 1, "", "sparse_mask"], [77, 2, 1, "", "sparse_resize_"], [77, 2, 1, "", "sparse_resize_and_clear_"], [77, 2, 1, "", "split"], [77, 2, 1, "", "sqrt"], [77, 2, 1, "", "sqrt_"], [77, 2, 1, "", "square"], [77, 2, 1, "", "square_"], [77, 2, 1, "", "squeeze"], [77, 2, 1, "", "squeeze_"], [77, 2, 1, "", "sspaddmm"], [77, 2, 1, "", "std"], [77, 2, 1, "", "stft"], [77, 2, 1, "", "storage"], [77, 2, 1, "", "storage_offset"], [77, 2, 1, "", "storage_type"], [77, 2, 1, "", "stride"], [77, 2, 1, "", "sub"], [77, 2, 1, "", "sub_"], [77, 2, 1, "", "subtract"], [77, 2, 1, "", "subtract_"], [77, 2, 1, "", "sum"], [77, 2, 1, "", "sum_to_size"], [77, 2, 1, "", "svd"], [77, 2, 1, "", "swapaxes"], [77, 2, 1, "", "swapaxes_"], [77, 2, 1, "", "swapdims"], [77, 2, 1, "", "swapdims_"], [77, 2, 1, "", "t"], [77, 2, 1, "", "t_"], [77, 2, 1, "", "take"], [77, 2, 1, "", "take_along_dim"], [77, 2, 1, "", "tan"], [77, 2, 1, "", "tan_"], [77, 2, 1, "", "tanh"], [77, 2, 1, "", "tanh_"], [77, 2, 1, "", "tensor_split"], [77, 2, 1, "", "tile"], [77, 2, 1, "", "to"], [77, 2, 1, "", "to_dense"], [77, 2, 1, "", "to_mkldnn"], [77, 2, 1, "", "to_padded_tensor"], [77, 2, 1, "", "to_sparse"], [77, 2, 1, "", "to_sparse_bsc"], [77, 2, 1, "", "to_sparse_bsr"], [77, 2, 1, "", "to_sparse_coo"], [77, 2, 1, "", "to_sparse_csc"], [77, 2, 1, "", "to_sparse_csr"], [77, 2, 1, "", "tolist"], [77, 2, 1, "", "topk"], [77, 2, 1, "", "trace"], [77, 2, 1, "", "transpose"], [77, 2, 1, "", "transpose_"], [77, 2, 1, "", "triangular_solve"], [77, 2, 1, "", "tril"], [77, 2, 1, "", "tril_"], [77, 2, 1, "", "triu"], [77, 2, 1, "", "triu_"], [77, 2, 1, "", "true_divide"], [77, 2, 1, "", "true_divide_"], [77, 2, 1, "", "trunc"], [77, 2, 1, "", "trunc_"], [77, 2, 1, "", "type"], [77, 2, 1, "", "type_as"], [77, 2, 1, "", "unbind"], [77, 2, 1, "", "unflatten"], [77, 2, 1, "", "unfold"], [77, 2, 1, "", "uniform_"], [77, 2, 1, "", "unique"], [77, 2, 1, "", "unique_consecutive"], [77, 2, 1, "", "unsafe_chunk"], [77, 2, 1, "", "unsafe_split"], [77, 2, 1, "", "unsqueeze"], [77, 2, 1, "", "unsqueeze_"], [77, 2, 1, "", "untyped_storage"], [77, 2, 1, "", "values"], [77, 2, 1, "", "var"], [77, 2, 1, "", "vdot"], [77, 2, 1, "", "view"], [77, 2, 1, "", "view_as"], [77, 2, 1, "", "vsplit"], [77, 2, 1, "", "where"], [77, 2, 1, "", "xlogy"], [77, 2, 1, "", "xlogy_"], [77, 2, 1, "", "xpu"], [77, 2, 1, "", "zero_"]], "unit_scaling.scale": [[80, 4, 1, "", "scale_bwd"], [81, 4, 1, "", "scale_fwd"]], "unit_scaling.transforms": [[84, 1, 1, "", "Metrics"], [85, 4, 1, "", "compile"], [86, 4, 1, "", "prune_non_float_tensors"], [87, 4, 1, "", "prune_same_scale_tensors"], [88, 4, 1, "", "prune_selected_nodes"], [89, 4, 1, "", "simulate_format"], [90, 4, 1, "", "simulate_fp8"], [91, 4, 1, "", "track_scales"], [92, 4, 1, "", "unit_scale"], [93, 0, 0, "-", "utils"]], "unit_scaling.transforms.Metrics": [[84, 1, 1, "", "Data"]], "unit_scaling.transforms.utils": [[94, 4, 1, "", "apply_transform"], [95, 4, 1, "", "patch_to_expand_modules"], [96, 4, 1, "", "replace_node_with_function"], [97, 4, 1, "", "torch_nn_modules_to_user_modules"]], "unit_scaling.utils": [[99, 1, 1, "", "ScalePair"], [100, 1, 1, "", "ScaleTracker"], [101, 1, 1, "", "ScaleTrackingInterpreter"], [102, 4, 1, "", "analyse_module"]], "unit_scaling.utils.ScaleTracker": [[100, 2, 1, "", "backward"], [100, 2, 1, "", "jvp"], [100, 2, 1, "", "mark_dirty"], [100, 2, 1, "", "mark_non_differentiable"], [100, 2, 1, "", "save_for_backward"], [100, 2, 1, "", "save_for_forward"], [100, 2, 1, "", "set_materialize_grads"], [100, 2, 1, "", "setup_context"], [100, 2, 1, "", "vjp"], [100, 2, 1, "", "vmap"]], "unit_scaling.utils.ScaleTrackingInterpreter": [[101, 2, 1, "", "boxed_run"], [101, 2, 1, "", "call_function"], [101, 2, 1, "", "call_method"], [101, 2, 1, "", "call_module"], [101, 2, 1, "", "fetch_args_kwargs_from_env"], [101, 2, 1, "", "fetch_attr"], [101, 2, 1, "", "get_attr"], [101, 2, 1, "", "map_nodes_to_values"], [101, 2, 1, "", "output"], [101, 2, 1, "", "placeholder"], [101, 2, 1, "", "run"], [101, 2, 1, "", "run_node"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:attribute", "4": "py:function", "5": "py:property"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "function", "Python function"], "5": ["py", "property", "Python property"]}, "titleterms": {"api": 0, "refer": 0, "unit": [1, 2, 104, 107], "scale": [1, 2, 79, 80, 81, 104, 106, 107], "blog": 1, "almost": [1, 106], "dot": [1, 106], "product": [1, 106], "self": 1, "attent": [1, 106], "maxim": 2, "updat": 2, "parameter": 2, "u": 2, "\u03bcp": 2, "instal": [2, 104, 107], "what": [2, 107], "i": [2, 107], "develop": [2, 104], "licens": 2, "unit_sc": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103], "crossentropyloss": 4, "depthmodulelist": 5, "depthsequenti": 6, "dropout": [7, 48], "embed": [8, 49], "gelu": [9, 50], "layernorm": 10, "linear": [11, 52], "linearreadout": 12, "mhsa": 13, "mlp": 14, "paramet": [15, 72, 73, 74, 75, 76, 77, 78], "rmsnorm": 16, "silu": [17, 61], "softmax": [18, 63], "transformerdecod": 19, "transformerlay": 20, "analysi": [21, 22, 23, 24, 25], "example_batch": 22, "graph_to_datafram": 23, "plot": 24, "visualis": [25, 103], "constraint": [26, 27, 28, 29, 30, 31, 32, 33, 34], "amean": 27, "apply_constraint": 28, "gmean": 29, "hmean": 30, "to_grad_input_scal": 31, "to_left_grad_scal": 32, "to_output_scal": 33, "to_right_grad_scal": 34, "core": [35, 36, 37, 38, 39, 40], "function": [36, 37, 38, 39, 40, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], "logarithmic_interpol": 37, "rm": 38, "scale_elementwis": 39, "transformer_residual_scaling_rul": [40, 82], "format": [41, 42, 43, 44], "fpformat": 42, "format_to_tupl": 43, "tuple_to_format": 44, "add": 46, "cross_entropi": 47, "layer_norm": 51, "linear_readout": 53, "matmul": 54, "mse_loss": 55, "residual_add": 56, "residual_appli": 57, "residual_split": 58, "rms_norm": 59, "scaled_dot_product_attent": 60, "silu_glu": 62, "optim": [64, 65, 66, 67, 68, 69, 70, 71], "adam": 65, "adamw": 66, "sgd": 67, "lr_scale_for_depth": 68, "lr_scale_func_adam": 69, "lr_scale_func_sgd": 70, "scaled_paramet": 71, "ordereddict": 73, "parameterdata": 75, "protocol": 76, "tensor": 77, "has_parameter_data": 78, "scale_bwd": 80, "scale_fwd": 81, "transform": [83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97], "metric": 84, "compil": 85, "prune_non_float_tensor": 86, "prune_same_scale_tensor": 87, "prune_selected_nod": 88, "simulate_format": 89, "simulate_fp8": 90, "track_scal": 91, "unit_scal": 92, "util": [93, 94, 95, 96, 97, 98, 99, 100, 101, 102], "apply_transform": 94, "patch_to_expand_modul": 95, "replace_node_with_funct": 96, "torch_nn_modules_to_user_modul": 97, "scalepair": 99, "scaletrack": 100, "scaletrackinginterpret": 101, "analyse_modul": 102, "get": 104, "start": 104, "content": 104, "limit": 105, "where": 106, "doe": 106, "d_": 106, "seq": 106, "e": 106, "1": 106, "2": 106, "come": 106, "from": 106, "work": 106, "No": 106, "conclus": 106, "user": 107, "guid": 107, "how": 107, "model": 107, "kei": 107, "consider": 107, "optimis": 107}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"API reference": [[0, "api-reference"]], "Unit Scaling blog": [[1, "unit-scaling-blog"]], "Almost scaled dot-product self attention": [[1, "almost-scaled-dot-product-self-attention"]], "Unit-Scaled Maximal Update Parameterization (u-\u03bcP)": [[2, "unit-scaled-maximal-update-parameterization-u-p"]], "Installation": [[2, "installation"], [104, "installation"], [107, "installation"]], "What is unit scaling?": [[2, "what-is-unit-scaling"], [107, "what-is-unit-scaling"]], "Development": [[2, "development"], [104, "development"]], "License": [[2, "license"]], "unit_scaling": [[3, "module-unit_scaling"]], "unit_scaling.CrossEntropyLoss": [[4, "unit-scaling-crossentropyloss"]], "unit_scaling.DepthModuleList": [[5, "unit-scaling-depthmodulelist"]], "unit_scaling.DepthSequential": [[6, "unit-scaling-depthsequential"]], "unit_scaling.Dropout": [[7, "unit-scaling-dropout"]], "unit_scaling.Embedding": [[8, "unit-scaling-embedding"]], "unit_scaling.GELU": [[9, "unit-scaling-gelu"]], "unit_scaling.LayerNorm": [[10, "unit-scaling-layernorm"]], "unit_scaling.Linear": [[11, "unit-scaling-linear"]], "unit_scaling.LinearReadout": [[12, "unit-scaling-linearreadout"]], "unit_scaling.MHSA": [[13, "unit-scaling-mhsa"]], "unit_scaling.MLP": [[14, "unit-scaling-mlp"]], "unit_scaling.Parameter": [[15, "unit-scaling-parameter"]], "unit_scaling.RMSNorm": [[16, "unit-scaling-rmsnorm"]], "unit_scaling.SiLU": [[17, "unit-scaling-silu"]], "unit_scaling.Softmax": [[18, "unit-scaling-softmax"]], "unit_scaling.TransformerDecoder": [[19, "unit-scaling-transformerdecoder"]], "unit_scaling.TransformerLayer": [[20, "unit-scaling-transformerlayer"]], "unit_scaling.analysis": [[21, "module-unit_scaling.analysis"]], "unit_scaling.analysis.example_batch": [[22, "unit-scaling-analysis-example-batch"]], "unit_scaling.analysis.graph_to_dataframe": [[23, "unit-scaling-analysis-graph-to-dataframe"]], "unit_scaling.analysis.plot": [[24, "unit-scaling-analysis-plot"]], "unit_scaling.analysis.visualiser": [[25, "unit-scaling-analysis-visualiser"]], "unit_scaling.constraints": [[26, "module-unit_scaling.constraints"]], "unit_scaling.constraints.amean": [[27, "unit-scaling-constraints-amean"]], "unit_scaling.constraints.apply_constraint": [[28, "unit-scaling-constraints-apply-constraint"]], "unit_scaling.constraints.gmean": [[29, "unit-scaling-constraints-gmean"]], "unit_scaling.constraints.hmean": [[30, "unit-scaling-constraints-hmean"]], "unit_scaling.constraints.to_grad_input_scale": [[31, "unit-scaling-constraints-to-grad-input-scale"]], "unit_scaling.constraints.to_left_grad_scale": [[32, "unit-scaling-constraints-to-left-grad-scale"]], "unit_scaling.constraints.to_output_scale": [[33, "unit-scaling-constraints-to-output-scale"]], "unit_scaling.constraints.to_right_grad_scale": [[34, "unit-scaling-constraints-to-right-grad-scale"]], "unit_scaling.core": [[35, "module-unit_scaling.core"]], "unit_scaling.core.functional": [[36, "module-unit_scaling.core.functional"]], "unit_scaling.core.functional.logarithmic_interpolation": [[37, "unit-scaling-core-functional-logarithmic-interpolation"]], "unit_scaling.core.functional.rms": [[38, "unit-scaling-core-functional-rms"]], "unit_scaling.core.functional.scale_elementwise": [[39, "unit-scaling-core-functional-scale-elementwise"]], "unit_scaling.core.functional.transformer_residual_scaling_rule": [[40, "unit-scaling-core-functional-transformer-residual-scaling-rule"]], "unit_scaling.formats": [[41, "module-unit_scaling.formats"]], "unit_scaling.formats.FPFormat": [[42, "unit-scaling-formats-fpformat"]], "unit_scaling.formats.format_to_tuple": [[43, "unit-scaling-formats-format-to-tuple"]], "unit_scaling.formats.tuple_to_format": [[44, "unit-scaling-formats-tuple-to-format"]], "unit_scaling.functional": [[45, "module-unit_scaling.functional"]], "unit_scaling.functional.add": [[46, "unit-scaling-functional-add"]], "unit_scaling.functional.cross_entropy": [[47, "unit-scaling-functional-cross-entropy"]], "unit_scaling.functional.dropout": [[48, "unit-scaling-functional-dropout"]], "unit_scaling.functional.embedding": [[49, "unit-scaling-functional-embedding"]], "unit_scaling.functional.gelu": [[50, "unit-scaling-functional-gelu"]], "unit_scaling.functional.layer_norm": [[51, "unit-scaling-functional-layer-norm"]], "unit_scaling.functional.linear": [[52, "unit-scaling-functional-linear"]], "unit_scaling.functional.linear_readout": [[53, "unit-scaling-functional-linear-readout"]], "unit_scaling.functional.matmul": [[54, "unit-scaling-functional-matmul"]], "unit_scaling.functional.mse_loss": [[55, "unit-scaling-functional-mse-loss"]], "unit_scaling.functional.residual_add": [[56, "unit-scaling-functional-residual-add"]], "unit_scaling.functional.residual_apply": [[57, "unit-scaling-functional-residual-apply"]], "unit_scaling.functional.residual_split": [[58, "unit-scaling-functional-residual-split"]], "unit_scaling.functional.rms_norm": [[59, "unit-scaling-functional-rms-norm"]], "unit_scaling.functional.scaled_dot_product_attention": [[60, "unit-scaling-functional-scaled-dot-product-attention"]], "unit_scaling.functional.silu": [[61, "unit-scaling-functional-silu"]], "unit_scaling.functional.silu_glu": [[62, "unit-scaling-functional-silu-glu"]], "unit_scaling.functional.softmax": [[63, "unit-scaling-functional-softmax"]], "unit_scaling.optim": [[64, "module-unit_scaling.optim"]], "unit_scaling.optim.Adam": [[65, "unit-scaling-optim-adam"]], "unit_scaling.optim.AdamW": [[66, "unit-scaling-optim-adamw"]], "unit_scaling.optim.SGD": [[67, "unit-scaling-optim-sgd"]], "unit_scaling.optim.lr_scale_for_depth": [[68, "unit-scaling-optim-lr-scale-for-depth"]], "unit_scaling.optim.lr_scale_func_adam": [[69, "unit-scaling-optim-lr-scale-func-adam"]], "unit_scaling.optim.lr_scale_func_sgd": [[70, "unit-scaling-optim-lr-scale-func-sgd"]], "unit_scaling.optim.scaled_parameters": [[71, "unit-scaling-optim-scaled-parameters"]], "unit_scaling.parameter": [[72, "module-unit_scaling.parameter"]], "unit_scaling.parameter.OrderedDict": [[73, "unit-scaling-parameter-ordereddict"]], "unit_scaling.parameter.Parameter": [[74, "unit-scaling-parameter-parameter"]], "unit_scaling.parameter.ParameterData": [[75, "unit-scaling-parameter-parameterdata"]], "unit_scaling.parameter.Protocol": [[76, "unit-scaling-parameter-protocol"]], "unit_scaling.parameter.Tensor": [[77, "unit-scaling-parameter-tensor"]], "unit_scaling.parameter.has_parameter_data": [[78, "unit-scaling-parameter-has-parameter-data"]], "unit_scaling.scale": [[79, "module-unit_scaling.scale"]], "unit_scaling.scale.scale_bwd": [[80, "unit-scaling-scale-scale-bwd"]], "unit_scaling.scale.scale_fwd": [[81, "unit-scaling-scale-scale-fwd"]], "unit_scaling.transformer_residual_scaling_rule": [[82, "unit-scaling-transformer-residual-scaling-rule"]], "unit_scaling.transforms": [[83, "module-unit_scaling.transforms"]], "unit_scaling.transforms.Metrics": [[84, "unit-scaling-transforms-metrics"]], "unit_scaling.transforms.compile": [[85, "unit-scaling-transforms-compile"]], "unit_scaling.transforms.prune_non_float_tensors": [[86, "unit-scaling-transforms-prune-non-float-tensors"]], "unit_scaling.transforms.prune_same_scale_tensors": [[87, "unit-scaling-transforms-prune-same-scale-tensors"]], "unit_scaling.transforms.prune_selected_nodes": [[88, "unit-scaling-transforms-prune-selected-nodes"]], "unit_scaling.transforms.simulate_format": [[89, "unit-scaling-transforms-simulate-format"]], "unit_scaling.transforms.simulate_fp8": [[90, "unit-scaling-transforms-simulate-fp8"]], "unit_scaling.transforms.track_scales": [[91, "unit-scaling-transforms-track-scales"]], "unit_scaling.transforms.unit_scale": [[92, "unit-scaling-transforms-unit-scale"]], "unit_scaling.transforms.utils": [[93, "module-unit_scaling.transforms.utils"]], "unit_scaling.transforms.utils.apply_transform": [[94, "unit-scaling-transforms-utils-apply-transform"]], "unit_scaling.transforms.utils.patch_to_expand_modules": [[95, "unit-scaling-transforms-utils-patch-to-expand-modules"]], "unit_scaling.transforms.utils.replace_node_with_function": [[96, "unit-scaling-transforms-utils-replace-node-with-function"]], "unit_scaling.transforms.utils.torch_nn_modules_to_user_modules": [[97, "unit-scaling-transforms-utils-torch-nn-modules-to-user-modules"]], "unit_scaling.utils": [[98, "module-unit_scaling.utils"]], "unit_scaling.utils.ScalePair": [[99, "unit-scaling-utils-scalepair"]], "unit_scaling.utils.ScaleTracker": [[100, "unit-scaling-utils-scaletracker"]], "unit_scaling.utils.ScaleTrackingInterpreter": [[101, "unit-scaling-utils-scaletrackinginterpreter"]], "unit_scaling.utils.analyse_module": [[102, "unit-scaling-utils-analyse-module"]], "unit_scaling.visualiser": [[103, "unit-scaling-visualiser"]], "Unit Scaling": [[104, "unit-scaling"]], "Getting Started": [[104, "getting-started"]], "Contents": [[104, null]], "Limitations": [[105, "limitations"]], "Almost-scaled dot-product attention": [[106, "almost-scaled-dot-product-attention"]], "Where does (d_{seq}/e)^{1/2} come from?": [[106, "where-does-d-seq-e-1-2-come-from"]], "Does it work? \u2026No!": [[106, "does-it-work-no"]], "Conclusion": [[106, "conclusion"]], "User guide": [[107, "user-guide"]], "How to unit-scale a model": [[107, "how-to-unit-scale-a-model"]], "Key considerations for unit scaling": [[107, "key-considerations-for-unit-scaling"]], "Optimising unit-scaled models": [[107, "optimising-unit-scaled-models"]]}, "indexentries": {"module": [[3, "module-unit_scaling"], [21, "module-unit_scaling.analysis"], [26, "module-unit_scaling.constraints"], [35, "module-unit_scaling.core"], [36, "module-unit_scaling.core.functional"], [41, "module-unit_scaling.formats"], [45, "module-unit_scaling.functional"], [64, "module-unit_scaling.optim"], [72, "module-unit_scaling.parameter"], [79, "module-unit_scaling.scale"], [83, "module-unit_scaling.transforms"], [93, "module-unit_scaling.transforms.utils"], [98, "module-unit_scaling.utils"]], "unit_scaling": [[3, "module-unit_scaling"]], "crossentropyloss (class in unit_scaling)": [[4, "unit_scaling.CrossEntropyLoss"]], "depthmodulelist (class in unit_scaling)": [[5, "unit_scaling.DepthModuleList"]], "append() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.append"]], "extend() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.extend"]], "insert() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.insert"]], "depthsequential (class in unit_scaling)": [[6, "unit_scaling.DepthSequential"]], "append() (unit_scaling.depthsequential method)": [[6, "unit_scaling.DepthSequential.append"]], "dropout (class in unit_scaling)": [[7, "unit_scaling.Dropout"]], "embedding (class in unit_scaling)": [[8, "unit_scaling.Embedding"]], "from_pretrained() (unit_scaling.embedding class method)": [[8, "unit_scaling.Embedding.from_pretrained"]], "weight (unit_scaling.embedding attribute)": [[8, "unit_scaling.Embedding.weight"]], "gelu (class in unit_scaling)": [[9, "unit_scaling.GELU"]], "layernorm (class in unit_scaling)": [[10, "unit_scaling.LayerNorm"]], "bias (unit_scaling.layernorm attribute)": [[10, "unit_scaling.LayerNorm.bias"]], "weight (unit_scaling.layernorm attribute)": [[10, "unit_scaling.LayerNorm.weight"]], "linear (class in unit_scaling)": [[11, "unit_scaling.Linear"]], "bias (unit_scaling.linear attribute)": [[11, "unit_scaling.Linear.bias"]], "weight (unit_scaling.linear attribute)": [[11, "unit_scaling.Linear.weight"]], "linearreadout (class in unit_scaling)": [[12, "unit_scaling.LinearReadout"]], "bias (unit_scaling.linearreadout attribute)": [[12, "unit_scaling.LinearReadout.bias"]], "weight (unit_scaling.linearreadout attribute)": [[12, "unit_scaling.LinearReadout.weight"]], "mhsa (class in unit_scaling)": [[13, "unit_scaling.MHSA"]], "mlp (class in unit_scaling)": [[14, "unit_scaling.MLP"]], "parameter() (in module unit_scaling)": [[15, "unit_scaling.Parameter"]], "rmsnorm (class in unit_scaling)": [[16, "unit_scaling.RMSNorm"]], "weight (unit_scaling.rmsnorm attribute)": [[16, "unit_scaling.RMSNorm.weight"]], "silu (class in unit_scaling)": [[17, "unit_scaling.SiLU"]], "softmax (class in unit_scaling)": [[18, "unit_scaling.Softmax"]], "transformerdecoder (class in unit_scaling)": [[19, "unit_scaling.TransformerDecoder"]], "append() (unit_scaling.transformerdecoder method)": [[19, "unit_scaling.TransformerDecoder.append"]], "transformerlayer (class in unit_scaling)": [[20, "unit_scaling.TransformerLayer"]], "unit_scaling.analysis": [[21, "module-unit_scaling.analysis"]], "example_batch() (in module unit_scaling.analysis)": [[22, "unit_scaling.analysis.example_batch"]], "graph_to_dataframe() (in module unit_scaling.analysis)": [[23, "unit_scaling.analysis.graph_to_dataframe"]], "plot() (in module unit_scaling.analysis)": [[24, "unit_scaling.analysis.plot"]], "visualiser() (in module unit_scaling.analysis)": [[25, "unit_scaling.analysis.visualiser"]], "unit_scaling.constraints": [[26, "module-unit_scaling.constraints"]], "amean() (in module unit_scaling.constraints)": [[27, "unit_scaling.constraints.amean"]], "apply_constraint() (in module unit_scaling.constraints)": [[28, "unit_scaling.constraints.apply_constraint"]], "gmean() (in module unit_scaling.constraints)": [[29, "unit_scaling.constraints.gmean"]], "hmean() (in module unit_scaling.constraints)": [[30, "unit_scaling.constraints.hmean"]], "to_grad_input_scale() (in module unit_scaling.constraints)": [[31, "unit_scaling.constraints.to_grad_input_scale"]], "to_left_grad_scale() (in module unit_scaling.constraints)": [[32, "unit_scaling.constraints.to_left_grad_scale"]], "to_output_scale() (in module unit_scaling.constraints)": [[33, "unit_scaling.constraints.to_output_scale"]], "to_right_grad_scale() (in module unit_scaling.constraints)": [[34, "unit_scaling.constraints.to_right_grad_scale"]], "unit_scaling.core": [[35, "module-unit_scaling.core"]], "unit_scaling.core.functional": [[36, "module-unit_scaling.core.functional"]], "logarithmic_interpolation() (in module unit_scaling.core.functional)": [[37, "unit_scaling.core.functional.logarithmic_interpolation"]], "rms() (in module unit_scaling.core.functional)": [[38, "unit_scaling.core.functional.rms"]], "scale_elementwise() (in module unit_scaling.core.functional)": [[39, "unit_scaling.core.functional.scale_elementwise"]], "transformer_residual_scaling_rule() (in module unit_scaling.core.functional)": [[40, "unit_scaling.core.functional.transformer_residual_scaling_rule"]], "unit_scaling.formats": [[41, "module-unit_scaling.formats"]], "fpformat (class in unit_scaling.formats)": [[42, "unit_scaling.formats.FPFormat"]], "bits (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.bits"]], "max_absolute_value (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.max_absolute_value"]], "min_absolute_normal (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.min_absolute_normal"]], "min_absolute_subnormal (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.min_absolute_subnormal"]], "quantise() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise"]], "quantise_bwd() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise_bwd"]], "quantise_fwd() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise_fwd"]], "format_to_tuple() (in module unit_scaling.formats)": [[43, "unit_scaling.formats.format_to_tuple"]], "tuple_to_format() (in module unit_scaling.formats)": [[44, "unit_scaling.formats.tuple_to_format"]], "unit_scaling.functional": [[45, "module-unit_scaling.functional"]], "add() (in module unit_scaling.functional)": [[46, "unit_scaling.functional.add"]], "cross_entropy() (in module unit_scaling.functional)": [[47, "unit_scaling.functional.cross_entropy"]], "dropout() (in module unit_scaling.functional)": [[48, "unit_scaling.functional.dropout"]], "embedding() (in module unit_scaling.functional)": [[49, "unit_scaling.functional.embedding"]], "gelu() (in module unit_scaling.functional)": [[50, "unit_scaling.functional.gelu"]], "layer_norm() (in module unit_scaling.functional)": [[51, "unit_scaling.functional.layer_norm"]], "linear() (in module unit_scaling.functional)": [[52, "unit_scaling.functional.linear"]], "linear_readout() (in module unit_scaling.functional)": [[53, "unit_scaling.functional.linear_readout"]], "matmul() (in module unit_scaling.functional)": [[54, "unit_scaling.functional.matmul"]], "mse_loss() (in module unit_scaling.functional)": [[55, "unit_scaling.functional.mse_loss"]], "residual_add() (in module unit_scaling.functional)": [[56, "unit_scaling.functional.residual_add"]], "residual_apply() (in module unit_scaling.functional)": [[57, "unit_scaling.functional.residual_apply"]], "residual_split() (in module unit_scaling.functional)": [[58, "unit_scaling.functional.residual_split"]], "rms_norm() (in module unit_scaling.functional)": [[59, "unit_scaling.functional.rms_norm"]], "scaled_dot_product_attention() (in module unit_scaling.functional)": [[60, "unit_scaling.functional.scaled_dot_product_attention"]], "silu() (in module unit_scaling.functional)": [[61, "unit_scaling.functional.silu"]], "silu_glu() (in module unit_scaling.functional)": [[62, "unit_scaling.functional.silu_glu"]], "softmax() (in module unit_scaling.functional)": [[63, "unit_scaling.functional.softmax"]], "unit_scaling.optim": [[64, "module-unit_scaling.optim"]], "adam (class in unit_scaling.optim)": [[65, "unit_scaling.optim.Adam"]], "add_param_group() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.add_param_group"]], "load_state_dict() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.state_dict"]], "step() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.step"]], "zero_grad() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.zero_grad"]], "adamw (class in unit_scaling.optim)": [[66, "unit_scaling.optim.AdamW"]], "add_param_group() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.add_param_group"]], "load_state_dict() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.state_dict"]], "step() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.step"]], "zero_grad() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.zero_grad"]], "sgd (class in unit_scaling.optim)": [[67, "unit_scaling.optim.SGD"]], "add_param_group() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.add_param_group"]], "load_state_dict() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.state_dict"]], "step() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.step"]], "zero_grad() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.zero_grad"]], "lr_scale_for_depth() (in module unit_scaling.optim)": [[68, "unit_scaling.optim.lr_scale_for_depth"]], "lr_scale_func_adam() (in module unit_scaling.optim)": [[69, "unit_scaling.optim.lr_scale_func_adam"]], "lr_scale_func_sgd() (in module unit_scaling.optim)": [[70, "unit_scaling.optim.lr_scale_func_sgd"]], "scaled_parameters() (in module unit_scaling.optim)": [[71, "unit_scaling.optim.scaled_parameters"]], "unit_scaling.parameter": [[72, "module-unit_scaling.parameter"]], "ordereddict (class in unit_scaling.parameter)": [[73, "unit_scaling.parameter.OrderedDict"]], "clear() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.clear"]], "copy() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.copy"]], "fromkeys() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.fromkeys"]], "get() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.get"]], "items() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.items"]], "keys() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.keys"]], "move_to_end() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.move_to_end"]], "pop() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.pop"]], "popitem() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.popitem"]], "setdefault() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.setdefault"]], "update() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.update"]], "values() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.values"]], "parameter() (in module unit_scaling.parameter)": [[74, "unit_scaling.parameter.Parameter"]], "parameterdata (class in unit_scaling.parameter)": [[75, "unit_scaling.parameter.ParameterData"]], "protocol (class in unit_scaling.parameter)": [[76, "unit_scaling.parameter.Protocol"]], "h (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.H"]], "t (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.T"]], "tensor (class in unit_scaling.parameter)": [[77, "unit_scaling.parameter.Tensor"]], "abs() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.abs"]], "abs_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.abs_"]], "absolute() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.absolute"]], "absolute_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.absolute_"]], "acos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acos"]], "acos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acos_"]], "acosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acosh"]], "acosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acosh_"]], "add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.add"]], "add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.add_"]], "addbmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addbmm"]], "addbmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addbmm_"]], "addcdiv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcdiv"]], "addcdiv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcdiv_"]], "addcmul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcmul"]], "addcmul_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcmul_"]], "addmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmm"]], "addmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmm_"]], "addmv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmv"]], "addmv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmv_"]], "addr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addr"]], "addr_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addr_"]], "adjoint() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.adjoint"]], "align_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.align_as"]], "align_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.align_to"]], "all() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.all"]], "allclose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.allclose"]], "amax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.amax"]], "amin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.amin"]], "aminmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.aminmax"]], "angle() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.angle"]], "any() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.any"]], "apply_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.apply_"]], "arccos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccos"]], "arccos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccos_"]], "arccosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccosh"]], "arccosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccosh_"]], "arcsin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsin"]], "arcsin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsin_"]], "arcsinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsinh"]], "arcsinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsinh_"]], "arctan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan"]], "arctan2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan2"]], "arctan2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan2_"]], "arctan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan_"]], "arctanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctanh"]], "arctanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctanh_"]], "argmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argmax"]], "argmin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argmin"]], "argsort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argsort"]], "argwhere() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argwhere"]], "as_strided() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided"]], "as_strided_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided_"]], "as_strided_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided_scatter"]], "as_subclass() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_subclass"]], "asin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asin"]], "asin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asin_"]], "asinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asinh"]], "asinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asinh_"]], "atan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan"]], "atan2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan2"]], "atan2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan2_"]], "atan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan_"]], "atanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atanh"]], "atanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atanh_"]], "backward() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.backward"]], "baddbmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.baddbmm"]], "baddbmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.baddbmm_"]], "bernoulli() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bernoulli"]], "bernoulli_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bernoulli_"]], "bfloat16() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bfloat16"]], "bincount() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bincount"]], "bitwise_and() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_and"]], "bitwise_and_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_and_"]], "bitwise_left_shift() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_left_shift"]], "bitwise_left_shift_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_left_shift_"]], "bitwise_not() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_not"]], "bitwise_not_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_not_"]], "bitwise_or() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_or"]], "bitwise_or_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_or_"]], "bitwise_right_shift() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_right_shift"]], "bitwise_right_shift_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_right_shift_"]], "bitwise_xor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_xor"]], "bitwise_xor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_xor_"]], "bmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bmm"]], "bool() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bool"]], "broadcast_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.broadcast_to"]], "byte() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.byte"]], "cauchy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cauchy_"]], "cdouble() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cdouble"]], "ceil() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ceil"]], "ceil_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ceil_"]], "cfloat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cfloat"]], "chalf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.chalf"]], "char() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.char"]], "cholesky() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky"]], "cholesky_inverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky_inverse"]], "cholesky_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky_solve"]], "chunk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.chunk"]], "clamp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clamp"]], "clamp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clamp_"]], "clip() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clip"]], "clip_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clip_"]], "clone() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clone"]], "coalesce() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.coalesce"]], "col_indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.col_indices"]], "conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj"]], "conj_physical() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj_physical"]], "conj_physical_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj_physical_"]], "contiguous() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.contiguous"]], "copy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copy_"]], "copysign() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copysign"]], "copysign_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copysign_"]], "corrcoef() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.corrcoef"]], "cos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cos"]], "cos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cos_"]], "cosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cosh"]], "cosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cosh_"]], "count_nonzero() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.count_nonzero"]], "cov() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cov"]], "cpu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cpu"]], "cross() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cross"]], "crow_indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.crow_indices"]], "cuda() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cuda"]], "cummax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cummax"]], "cummin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cummin"]], "cumprod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumprod"]], "cumprod_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumprod_"]], "cumsum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumsum"]], "cumsum_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumsum_"]], "data_ptr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.data_ptr"]], "deg2rad() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.deg2rad"]], "deg2rad_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.deg2rad_"]], "dense_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dense_dim"]], "dequantize() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dequantize"]], "det() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.det"]], "detach() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.detach"]], "detach_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.detach_"]], "device (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.device"]], "diag() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diag"]], "diag_embed() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diag_embed"]], "diagflat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagflat"]], "diagonal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagonal"]], "diagonal_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagonal_scatter"]], "diff() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diff"]], "digamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.digamma"]], "digamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.digamma_"]], "dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dim"]], "dim_order() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dim_order"]], "dist() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dist"]], "div() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.div"]], "div_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.div_"]], "divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.divide"]], "divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.divide_"]], "dot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dot"]], "double() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.double"]], "dsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dsplit"]], "element_size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.element_size"]], "eq() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.eq"]], "eq_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.eq_"]], "equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.equal"]], "erf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erf"]], "erf_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erf_"]], "erfc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfc"]], "erfc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfc_"]], "erfinv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfinv"]], "erfinv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfinv_"]], "exp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp"]], "exp2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp2"]], "exp2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp2_"]], "exp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp_"]], "expand() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expand"]], "expand_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expand_as"]], "expm1() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expm1"]], "expm1_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expm1_"]], "exponential_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exponential_"]], "fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fill_"]], "fill_diagonal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fill_diagonal_"]], "fix() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fix"]], "fix_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fix_"]], "flatten() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flatten"]], "flip() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flip"]], "fliplr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fliplr"]], "flipud() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flipud"]], "float() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float"]], "float_power() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float_power"]], "float_power_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float_power_"]], "floor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor"]], "floor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_"]], "floor_divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_divide"]], "floor_divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_divide_"]], "fmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmax"]], "fmin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmin"]], "fmod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmod"]], "fmod_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmod_"]], "frac() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frac"]], "frac_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frac_"]], "frexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frexp"]], "gather() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gather"]], "gcd() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gcd"]], "gcd_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gcd_"]], "ge() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ge"]], "ge_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ge_"]], "geometric_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.geometric_"]], "geqrf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.geqrf"]], "ger() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ger"]], "get_device() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.get_device"]], "grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.grad"]], "greater() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater"]], "greater_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_"]], "greater_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_equal"]], "greater_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_equal_"]], "gt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gt"]], "gt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gt_"]], "half() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.half"]], "hardshrink() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hardshrink"]], "has_names() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.has_names"]], "heaviside() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.heaviside"]], "heaviside_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.heaviside_"]], "histc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.histc"]], "histogram() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.histogram"]], "hsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hsplit"]], "hypot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hypot"]], "hypot_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hypot_"]], "i0() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.i0"]], "i0_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.i0_"]], "igamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igamma"]], "igamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igamma_"]], "igammac() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igammac"]], "igammac_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igammac_"]], "imag (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.imag"]], "index_add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_add"]], "index_add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_add_"]], "index_copy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_copy"]], "index_copy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_copy_"]], "index_fill() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_fill"]], "index_fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_fill_"]], "index_put() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_put"]], "index_put_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_put_"]], "index_reduce_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_reduce_"]], "index_select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_select"]], "indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.indices"]], "inner() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.inner"]], "int() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.int"]], "int_repr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.int_repr"]], "inverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.inverse"]], "ipu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ipu"]], "is_coalesced() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_coalesced"]], "is_complex() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_complex"]], "is_conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_conj"]], "is_contiguous() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_contiguous"]], "is_cpu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_cpu"]], "is_cuda (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_cuda"]], "is_floating_point() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_floating_point"]], "is_inference() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_inference"]], "is_ipu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_ipu"]], "is_leaf (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_leaf"]], "is_meta (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_meta"]], "is_mps (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_mps"]], "is_neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_neg"]], "is_pinned() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_pinned"]], "is_quantized (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_quantized"]], "is_set_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_set_to"]], "is_shared() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_shared"]], "is_signed() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_signed"]], "is_sparse (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_sparse"]], "is_sparse_csr (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_sparse_csr"]], "is_xla (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_xla"]], "is_xpu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_xpu"]], "isclose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isclose"]], "isfinite() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isfinite"]], "isinf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isinf"]], "isnan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isnan"]], "isneginf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isneginf"]], "isposinf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isposinf"]], "isreal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isreal"]], "istft() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.istft"]], "item() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.item"]], "itemsize (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.itemsize"]], "kron() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.kron"]], "kthvalue() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.kthvalue"]], "lcm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lcm"]], "lcm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lcm_"]], "ldexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ldexp"]], "ldexp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ldexp_"]], "le() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.le"]], "le_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.le_"]], "lerp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lerp"]], "lerp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lerp_"]], "less() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less"]], "less_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_"]], "less_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_equal"]], "less_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_equal_"]], "lgamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lgamma"]], "lgamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lgamma_"]], "log() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log"]], "log10() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log10"]], "log10_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log10_"]], "log1p() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log1p"]], "log1p_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log1p_"]], "log2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log2"]], "log2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log2_"]], "log_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log_"]], "log_normal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log_normal_"]], "logaddexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logaddexp"]], "logaddexp2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logaddexp2"]], "logcumsumexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logcumsumexp"]], "logdet() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logdet"]], "logical_and() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_and"]], "logical_and_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_and_"]], "logical_not() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_not"]], "logical_not_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_not_"]], "logical_or() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_or"]], "logical_or_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_or_"]], "logical_xor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_xor"]], "logical_xor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_xor_"]], "logit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logit"]], "logit_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logit_"]], "logsumexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logsumexp"]], "long() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.long"]], "lt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lt"]], "lt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lt_"]], "lu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lu"]], "lu_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lu_solve"]], "mh (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.mH"]], "mt (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.mT"]], "map_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.map_"]], "masked_fill() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_fill"]], "masked_fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_fill_"]], "masked_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_scatter"]], "masked_scatter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_scatter_"]], "masked_select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_select"]], "matmul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matmul"]], "matrix_exp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matrix_exp"]], "matrix_power() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matrix_power"]], "max() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.max"]], "maximum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.maximum"]], "mean() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mean"]], "median() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.median"]], "min() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.min"]], "minimum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.minimum"]], "mm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mm"]], "mode() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mode"]], "module_load() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.module_load"]], "moveaxis() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.moveaxis"]], "movedim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.movedim"]], "msort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.msort"]], "mul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mul"]], "mul_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mul_"]], "multinomial() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multinomial"]], "multiply() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multiply"]], "multiply_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multiply_"]], "mv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mv"]], "mvlgamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mvlgamma"]], "mvlgamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mvlgamma_"]], "names (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.names"]], "nan_to_num() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nan_to_num"]], "nan_to_num_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nan_to_num_"]], "nanmean() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanmean"]], "nanmedian() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanmedian"]], "nanquantile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanquantile"]], "nansum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nansum"]], "narrow() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.narrow"]], "narrow_copy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.narrow_copy"]], "nbytes (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.nbytes"]], "ndim (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.ndim"]], "ndimension() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ndimension"]], "ne() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ne"]], "ne_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ne_"]], "neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.neg"]], "neg_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.neg_"]], "negative() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.negative"]], "negative_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.negative_"]], "nelement() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nelement"]], "new_empty() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_empty"]], "new_empty_strided() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_empty_strided"]], "new_full() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_full"]], "new_ones() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_ones"]], "new_tensor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_tensor"]], "new_zeros() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_zeros"]], "nextafter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nextafter"]], "nextafter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nextafter_"]], "nonzero() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nonzero"]], "nonzero_static() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nonzero_static"]], "norm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.norm"]], "normal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.normal_"]], "not_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.not_equal"]], "not_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.not_equal_"]], "numel() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.numel"]], "numpy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.numpy"]], "orgqr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.orgqr"]], "ormqr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ormqr"]], "outer() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.outer"]], "permute() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.permute"]], "pin_memory() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pin_memory"]], "pinverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pinverse"]], "polygamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.polygamma"]], "polygamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.polygamma_"]], "positive() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.positive"]], "pow() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pow"]], "pow_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pow_"]], "prod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.prod"]], "put() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.put"]], "put_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.put_"]], "q_per_channel_axis() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_axis"]], "q_per_channel_scales() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_scales"]], "q_per_channel_zero_points() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_zero_points"]], "q_scale() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_scale"]], "q_zero_point() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_zero_point"]], "qr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.qr"]], "qscheme() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.qscheme"]], "quantile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.quantile"]], "rad2deg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rad2deg"]], "rad2deg_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rad2deg_"]], "random_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.random_"]], "ravel() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ravel"]], "real (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.real"]], "reciprocal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reciprocal"]], "reciprocal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reciprocal_"]], "record_stream() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.record_stream"]], "refine_names() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.refine_names"]], "register_hook() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.register_hook"]], "register_post_accumulate_grad_hook() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.register_post_accumulate_grad_hook"]], "remainder() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.remainder"]], "remainder_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.remainder_"]], "rename() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rename"]], "rename_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rename_"]], "renorm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.renorm"]], "renorm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.renorm_"]], "repeat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.repeat"]], "repeat_interleave() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.repeat_interleave"]], "requires_grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.requires_grad"]], "requires_grad_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.requires_grad_"]], "reshape() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reshape"]], "reshape_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reshape_as"]], "resize_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resize_"]], "resize_as_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resize_as_"]], "resolve_conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resolve_conj"]], "resolve_neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resolve_neg"]], "retain_grad() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.retain_grad"]], "retains_grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.retains_grad"]], "roll() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.roll"]], "rot90() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rot90"]], "round() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.round"]], "round_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.round_"]], "rsqrt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rsqrt"]], "rsqrt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rsqrt_"]], "scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter"]], "scatter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_"]], "scatter_add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_add"]], "scatter_add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_add_"]], "scatter_reduce() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_reduce"]], "scatter_reduce_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_reduce_"]], "select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.select"]], "select_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.select_scatter"]], "set_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.set_"]], "sgn() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sgn"]], "sgn_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sgn_"]], "shape (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.shape"]], "share_memory_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.share_memory_"]], "short() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.short"]], "sigmoid() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sigmoid"]], "sigmoid_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sigmoid_"]], "sign() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sign"]], "sign_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sign_"]], "signbit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.signbit"]], "sin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sin"]], "sin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sin_"]], "sinc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinc"]], "sinc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinc_"]], "sinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinh"]], "sinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinh_"]], "size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.size"]], "slice_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.slice_scatter"]], "slogdet() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.slogdet"]], "smm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.smm"]], "softmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.softmax"]], "sort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sort"]], "sparse_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_dim"]], "sparse_mask() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_mask"]], "sparse_resize_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_resize_"]], "sparse_resize_and_clear_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_resize_and_clear_"]], "split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.split"]], "sqrt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sqrt"]], "sqrt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sqrt_"]], "square() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.square"]], "square_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.square_"]], "squeeze() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.squeeze"]], "squeeze_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.squeeze_"]], "sspaddmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sspaddmm"]], "std() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.std"]], "stft() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.stft"]], "storage() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage"]], "storage_offset() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage_offset"]], "storage_type() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage_type"]], "stride() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.stride"]], "sub() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sub"]], "sub_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sub_"]], "subtract() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.subtract"]], "subtract_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.subtract_"]], "sum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sum"]], "sum_to_size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sum_to_size"]], "svd() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.svd"]], "swapaxes() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapaxes"]], "swapaxes_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapaxes_"]], "swapdims() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapdims"]], "swapdims_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapdims_"]], "t() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.t"]], "t_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.t_"]], "take() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.take"]], "take_along_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.take_along_dim"]], "tan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tan"]], "tan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tan_"]], "tanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tanh"]], "tanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tanh_"]], "tensor_split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tensor_split"]], "tile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tile"]], "to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to"]], "to_dense() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_dense"]], "to_mkldnn() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_mkldnn"]], "to_padded_tensor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_padded_tensor"]], "to_sparse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse"]], "to_sparse_bsc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_bsc"]], "to_sparse_bsr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_bsr"]], "to_sparse_coo() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_coo"]], "to_sparse_csc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_csc"]], "to_sparse_csr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_csr"]], "tolist() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tolist"]], "topk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.topk"]], "trace() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trace"]], "transpose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.transpose"]], "transpose_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.transpose_"]], "triangular_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triangular_solve"]], "tril() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tril"]], "tril_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tril_"]], "triu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triu"]], "triu_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triu_"]], "true_divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.true_divide"]], "true_divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.true_divide_"]], "trunc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trunc"]], "trunc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trunc_"]], "type() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.type"]], "type_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.type_as"]], "unbind() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unbind"]], "unflatten() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unflatten"]], "unfold() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unfold"]], "uniform_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.uniform_"]], "unique() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unique"]], "unique_consecutive() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unique_consecutive"]], "unsafe_chunk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsafe_chunk"]], "unsafe_split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsafe_split"]], "unsqueeze() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsqueeze"]], "unsqueeze_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsqueeze_"]], "untyped_storage() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.untyped_storage"]], "values() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.values"]], "var() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.var"]], "vdot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.vdot"]], "view() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.view"]], "view_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.view_as"]], "vsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.vsplit"]], "where() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.where"]], "xlogy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xlogy"]], "xlogy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xlogy_"]], "xpu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xpu"]], "zero_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.zero_"]], "has_parameter_data() (in module unit_scaling.parameter)": [[78, "unit_scaling.parameter.has_parameter_data"]], "unit_scaling.scale": [[79, "module-unit_scaling.scale"]], "scale_bwd() (in module unit_scaling.scale)": [[80, "unit_scaling.scale.scale_bwd"]], "scale_fwd() (in module unit_scaling.scale)": [[81, "unit_scaling.scale.scale_fwd"]], "transformer_residual_scaling_rule() (in module unit_scaling)": [[82, "unit_scaling.transformer_residual_scaling_rule"]], "unit_scaling.transforms": [[83, "module-unit_scaling.transforms"]], "metrics (class in unit_scaling.transforms)": [[84, "unit_scaling.transforms.Metrics"]], "metrics.data (class in unit_scaling.transforms)": [[84, "unit_scaling.transforms.Metrics.Data"]], "compile() (in module unit_scaling.transforms)": [[85, "unit_scaling.transforms.compile"]], "prune_non_float_tensors() (in module unit_scaling.transforms)": [[86, "unit_scaling.transforms.prune_non_float_tensors"]], "prune_same_scale_tensors() (in module unit_scaling.transforms)": [[87, "unit_scaling.transforms.prune_same_scale_tensors"]], "prune_selected_nodes() (in module unit_scaling.transforms)": [[88, "unit_scaling.transforms.prune_selected_nodes"]], "simulate_format() (in module unit_scaling.transforms)": [[89, "unit_scaling.transforms.simulate_format"]], "simulate_fp8() (in module unit_scaling.transforms)": [[90, "unit_scaling.transforms.simulate_fp8"]], "track_scales() (in module unit_scaling.transforms)": [[91, "unit_scaling.transforms.track_scales"]], "unit_scale() (in module unit_scaling.transforms)": [[92, "unit_scaling.transforms.unit_scale"]], "unit_scaling.transforms.utils": [[93, "module-unit_scaling.transforms.utils"]], "apply_transform() (in module unit_scaling.transforms.utils)": [[94, "unit_scaling.transforms.utils.apply_transform"]], "patch_to_expand_modules() (in module unit_scaling.transforms.utils)": [[95, "unit_scaling.transforms.utils.patch_to_expand_modules"]], "replace_node_with_function() (in module unit_scaling.transforms.utils)": [[96, "unit_scaling.transforms.utils.replace_node_with_function"]], "torch_nn_modules_to_user_modules() (in module unit_scaling.transforms.utils)": [[97, "unit_scaling.transforms.utils.torch_nn_modules_to_user_modules"]], "unit_scaling.utils": [[98, "module-unit_scaling.utils"]], "scalepair (class in unit_scaling.utils)": [[99, "unit_scaling.utils.ScalePair"]], "scaletracker (class in unit_scaling.utils)": [[100, "unit_scaling.utils.ScaleTracker"]], "backward() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.backward"]], "jvp() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.jvp"]], "mark_dirty() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.mark_dirty"]], "mark_non_differentiable() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.mark_non_differentiable"]], "save_for_backward() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.save_for_backward"]], "save_for_forward() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.save_for_forward"]], "set_materialize_grads() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.set_materialize_grads"]], "setup_context() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.setup_context"]], "vjp() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.vjp"]], "vmap() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.vmap"]], "scaletrackinginterpreter (class in unit_scaling.utils)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter"]], "boxed_run() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.boxed_run"]], "call_function() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_function"]], "call_method() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_method"]], "call_module() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_module"]], "fetch_args_kwargs_from_env() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.fetch_args_kwargs_from_env"]], "fetch_attr() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.fetch_attr"]], "get_attr() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.get_attr"]], "map_nodes_to_values() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.map_nodes_to_values"]], "output() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.output"]], "placeholder() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.placeholder"]], "run() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.run"]], "run_node() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.run_node"]], "analyse_module() (in module unit_scaling.utils)": [[102, "unit_scaling.utils.analyse_module"]], "visualiser() (in module unit_scaling)": [[103, "unit_scaling.visualiser"]]}})
\ No newline at end of file
+Search.setIndex({"docnames": ["api_reference", "blog", "development", "generated/unit_scaling", "generated/unit_scaling.CrossEntropyLoss", "generated/unit_scaling.DepthModuleList", "generated/unit_scaling.DepthSequential", "generated/unit_scaling.Dropout", "generated/unit_scaling.Embedding", "generated/unit_scaling.GELU", "generated/unit_scaling.LayerNorm", "generated/unit_scaling.Linear", "generated/unit_scaling.LinearReadout", "generated/unit_scaling.MHSA", "generated/unit_scaling.MLP", "generated/unit_scaling.Parameter", "generated/unit_scaling.RMSNorm", "generated/unit_scaling.SiLU", "generated/unit_scaling.Softmax", "generated/unit_scaling.TransformerDecoder", "generated/unit_scaling.TransformerLayer", "generated/unit_scaling.analysis", "generated/unit_scaling.analysis.example_batch", "generated/unit_scaling.analysis.graph_to_dataframe", "generated/unit_scaling.analysis.plot", "generated/unit_scaling.analysis.visualiser", "generated/unit_scaling.constraints", "generated/unit_scaling.constraints.amean", "generated/unit_scaling.constraints.apply_constraint", "generated/unit_scaling.constraints.gmean", "generated/unit_scaling.constraints.hmean", "generated/unit_scaling.constraints.to_grad_input_scale", "generated/unit_scaling.constraints.to_left_grad_scale", "generated/unit_scaling.constraints.to_output_scale", "generated/unit_scaling.constraints.to_right_grad_scale", "generated/unit_scaling.core", "generated/unit_scaling.core.functional", "generated/unit_scaling.core.functional.logarithmic_interpolation", "generated/unit_scaling.core.functional.rms", "generated/unit_scaling.core.functional.scale_elementwise", "generated/unit_scaling.core.functional.transformer_residual_scaling_rule", "generated/unit_scaling.formats", "generated/unit_scaling.formats.FPFormat", "generated/unit_scaling.formats.format_to_tuple", "generated/unit_scaling.formats.tuple_to_format", "generated/unit_scaling.functional", "generated/unit_scaling.functional.add", "generated/unit_scaling.functional.cross_entropy", "generated/unit_scaling.functional.dropout", "generated/unit_scaling.functional.embedding", "generated/unit_scaling.functional.gelu", "generated/unit_scaling.functional.layer_norm", "generated/unit_scaling.functional.linear", "generated/unit_scaling.functional.linear_readout", "generated/unit_scaling.functional.matmul", "generated/unit_scaling.functional.mse_loss", "generated/unit_scaling.functional.residual_add", "generated/unit_scaling.functional.residual_apply", "generated/unit_scaling.functional.residual_split", "generated/unit_scaling.functional.rms_norm", "generated/unit_scaling.functional.scaled_dot_product_attention", "generated/unit_scaling.functional.silu", "generated/unit_scaling.functional.silu_glu", "generated/unit_scaling.functional.softmax", "generated/unit_scaling.optim", "generated/unit_scaling.optim.Adam", "generated/unit_scaling.optim.AdamW", "generated/unit_scaling.optim.SGD", "generated/unit_scaling.optim.lr_scale_for_depth", "generated/unit_scaling.optim.lr_scale_func_adam", "generated/unit_scaling.optim.lr_scale_func_sgd", "generated/unit_scaling.optim.scaled_parameters", "generated/unit_scaling.parameter", "generated/unit_scaling.parameter.OrderedDict", "generated/unit_scaling.parameter.Parameter", "generated/unit_scaling.parameter.ParameterData", "generated/unit_scaling.parameter.Protocol", "generated/unit_scaling.parameter.Tensor", "generated/unit_scaling.parameter.has_parameter_data", "generated/unit_scaling.scale", "generated/unit_scaling.scale.scale_bwd", "generated/unit_scaling.scale.scale_fwd", "generated/unit_scaling.transformer_residual_scaling_rule", "generated/unit_scaling.transforms", "generated/unit_scaling.transforms.Metrics", "generated/unit_scaling.transforms.compile", "generated/unit_scaling.transforms.prune_non_float_tensors", "generated/unit_scaling.transforms.prune_same_scale_tensors", "generated/unit_scaling.transforms.prune_selected_nodes", "generated/unit_scaling.transforms.simulate_format", "generated/unit_scaling.transforms.simulate_fp8", "generated/unit_scaling.transforms.track_scales", "generated/unit_scaling.transforms.unit_scale", "generated/unit_scaling.transforms.utils", "generated/unit_scaling.transforms.utils.apply_transform", "generated/unit_scaling.transforms.utils.patch_to_expand_modules", "generated/unit_scaling.transforms.utils.replace_node_with_function", "generated/unit_scaling.transforms.utils.torch_nn_modules_to_user_modules", "generated/unit_scaling.utils", "generated/unit_scaling.utils.ScalePair", "generated/unit_scaling.utils.ScaleTracker", "generated/unit_scaling.utils.ScaleTrackingInterpreter", "generated/unit_scaling.utils.analyse_module", "generated/unit_scaling.visualiser", "index", "limitations", "posts/almost_scaled_dot_product_attention", "user_guide"], "filenames": ["api_reference.rst", "blog.rst", "development.md", "generated/unit_scaling.rst", "generated/unit_scaling.CrossEntropyLoss.rst", "generated/unit_scaling.DepthModuleList.rst", "generated/unit_scaling.DepthSequential.rst", "generated/unit_scaling.Dropout.rst", "generated/unit_scaling.Embedding.rst", "generated/unit_scaling.GELU.rst", "generated/unit_scaling.LayerNorm.rst", "generated/unit_scaling.Linear.rst", "generated/unit_scaling.LinearReadout.rst", "generated/unit_scaling.MHSA.rst", "generated/unit_scaling.MLP.rst", "generated/unit_scaling.Parameter.rst", "generated/unit_scaling.RMSNorm.rst", "generated/unit_scaling.SiLU.rst", "generated/unit_scaling.Softmax.rst", "generated/unit_scaling.TransformerDecoder.rst", "generated/unit_scaling.TransformerLayer.rst", "generated/unit_scaling.analysis.rst", "generated/unit_scaling.analysis.example_batch.rst", "generated/unit_scaling.analysis.graph_to_dataframe.rst", "generated/unit_scaling.analysis.plot.rst", "generated/unit_scaling.analysis.visualiser.rst", "generated/unit_scaling.constraints.rst", "generated/unit_scaling.constraints.amean.rst", "generated/unit_scaling.constraints.apply_constraint.rst", "generated/unit_scaling.constraints.gmean.rst", "generated/unit_scaling.constraints.hmean.rst", "generated/unit_scaling.constraints.to_grad_input_scale.rst", "generated/unit_scaling.constraints.to_left_grad_scale.rst", "generated/unit_scaling.constraints.to_output_scale.rst", "generated/unit_scaling.constraints.to_right_grad_scale.rst", "generated/unit_scaling.core.rst", "generated/unit_scaling.core.functional.rst", "generated/unit_scaling.core.functional.logarithmic_interpolation.rst", "generated/unit_scaling.core.functional.rms.rst", "generated/unit_scaling.core.functional.scale_elementwise.rst", "generated/unit_scaling.core.functional.transformer_residual_scaling_rule.rst", "generated/unit_scaling.formats.rst", "generated/unit_scaling.formats.FPFormat.rst", "generated/unit_scaling.formats.format_to_tuple.rst", "generated/unit_scaling.formats.tuple_to_format.rst", "generated/unit_scaling.functional.rst", "generated/unit_scaling.functional.add.rst", "generated/unit_scaling.functional.cross_entropy.rst", "generated/unit_scaling.functional.dropout.rst", "generated/unit_scaling.functional.embedding.rst", "generated/unit_scaling.functional.gelu.rst", "generated/unit_scaling.functional.layer_norm.rst", "generated/unit_scaling.functional.linear.rst", "generated/unit_scaling.functional.linear_readout.rst", "generated/unit_scaling.functional.matmul.rst", "generated/unit_scaling.functional.mse_loss.rst", "generated/unit_scaling.functional.residual_add.rst", "generated/unit_scaling.functional.residual_apply.rst", "generated/unit_scaling.functional.residual_split.rst", "generated/unit_scaling.functional.rms_norm.rst", "generated/unit_scaling.functional.scaled_dot_product_attention.rst", "generated/unit_scaling.functional.silu.rst", "generated/unit_scaling.functional.silu_glu.rst", "generated/unit_scaling.functional.softmax.rst", "generated/unit_scaling.optim.rst", "generated/unit_scaling.optim.Adam.rst", "generated/unit_scaling.optim.AdamW.rst", "generated/unit_scaling.optim.SGD.rst", "generated/unit_scaling.optim.lr_scale_for_depth.rst", "generated/unit_scaling.optim.lr_scale_func_adam.rst", "generated/unit_scaling.optim.lr_scale_func_sgd.rst", "generated/unit_scaling.optim.scaled_parameters.rst", "generated/unit_scaling.parameter.rst", "generated/unit_scaling.parameter.OrderedDict.rst", "generated/unit_scaling.parameter.Parameter.rst", "generated/unit_scaling.parameter.ParameterData.rst", "generated/unit_scaling.parameter.Protocol.rst", "generated/unit_scaling.parameter.Tensor.rst", "generated/unit_scaling.parameter.has_parameter_data.rst", "generated/unit_scaling.scale.rst", "generated/unit_scaling.scale.scale_bwd.rst", "generated/unit_scaling.scale.scale_fwd.rst", "generated/unit_scaling.transformer_residual_scaling_rule.rst", "generated/unit_scaling.transforms.rst", "generated/unit_scaling.transforms.Metrics.rst", "generated/unit_scaling.transforms.compile.rst", "generated/unit_scaling.transforms.prune_non_float_tensors.rst", "generated/unit_scaling.transforms.prune_same_scale_tensors.rst", "generated/unit_scaling.transforms.prune_selected_nodes.rst", "generated/unit_scaling.transforms.simulate_format.rst", "generated/unit_scaling.transforms.simulate_fp8.rst", "generated/unit_scaling.transforms.track_scales.rst", "generated/unit_scaling.transforms.unit_scale.rst", "generated/unit_scaling.transforms.utils.rst", "generated/unit_scaling.transforms.utils.apply_transform.rst", "generated/unit_scaling.transforms.utils.patch_to_expand_modules.rst", "generated/unit_scaling.transforms.utils.replace_node_with_function.rst", "generated/unit_scaling.transforms.utils.torch_nn_modules_to_user_modules.rst", "generated/unit_scaling.utils.rst", "generated/unit_scaling.utils.ScalePair.rst", "generated/unit_scaling.utils.ScaleTracker.rst", "generated/unit_scaling.utils.ScaleTrackingInterpreter.rst", "generated/unit_scaling.utils.analyse_module.rst", "generated/unit_scaling.visualiser.rst", "index.rst", "limitations.rst", "posts/almost_scaled_dot_product_attention.md", "user_guide.rst"], "titles": ["5. API reference", "4. Unit Scaling blog", "2. Development", "5.1. unit_scaling", "5.1.4. unit_scaling.CrossEntropyLoss", "5.1.5. unit_scaling.DepthModuleList", "5.1.6. unit_scaling.DepthSequential", "5.1.7. unit_scaling.Dropout", "5.1.8. unit_scaling.Embedding", "5.1.9. unit_scaling.GELU", "5.1.10. unit_scaling.LayerNorm", "5.1.11. unit_scaling.Linear", "5.1.12. unit_scaling.LinearReadout", "5.1.13. unit_scaling.MHSA", "5.1.14. unit_scaling.MLP", "5.1.1. unit_scaling.Parameter", "5.1.15. unit_scaling.RMSNorm", "5.1.16. unit_scaling.SiLU", "5.1.17. unit_scaling.Softmax", "5.1.18. unit_scaling.TransformerDecoder", "5.1.19. unit_scaling.TransformerLayer", "5.2. unit_scaling.analysis", "5.2.1. unit_scaling.analysis.example_batch", "5.2.2. unit_scaling.analysis.graph_to_dataframe", "5.2.3. unit_scaling.analysis.plot", "5.2.4. unit_scaling.analysis.visualiser", "5.3. unit_scaling.constraints", "5.3.1. unit_scaling.constraints.amean", "5.3.2. unit_scaling.constraints.apply_constraint", "5.3.3. unit_scaling.constraints.gmean", "5.3.4. unit_scaling.constraints.hmean", "5.3.5. unit_scaling.constraints.to_grad_input_scale", "5.3.6. unit_scaling.constraints.to_left_grad_scale", "5.3.7. unit_scaling.constraints.to_output_scale", "5.3.8. unit_scaling.constraints.to_right_grad_scale", "5.1.20. unit_scaling.core", "5.1.20.1. unit_scaling.core.functional", "5.1.20.1.1. unit_scaling.core.functional.logarithmic_interpolation", "5.1.20.1.2. unit_scaling.core.functional.rms", "5.1.20.1.3. unit_scaling.core.functional.scale_elementwise", "5.1.20.1.4. unit_scaling.core.functional.transformer_residual_scaling_rule", "5.4. unit_scaling.formats", "5.4.3. unit_scaling.formats.FPFormat", "5.4.1. unit_scaling.formats.format_to_tuple", "5.4.2. unit_scaling.formats.tuple_to_format", "5.1.21. unit_scaling.functional", "5.1.21.1. unit_scaling.functional.add", "5.1.21.2. unit_scaling.functional.cross_entropy", "5.1.21.3. unit_scaling.functional.dropout", "5.1.21.4. unit_scaling.functional.embedding", "5.1.21.5. unit_scaling.functional.gelu", "5.1.21.6. unit_scaling.functional.layer_norm", "5.1.21.7. unit_scaling.functional.linear", "5.1.21.8. unit_scaling.functional.linear_readout", "5.1.21.9. unit_scaling.functional.matmul", "5.1.21.10. unit_scaling.functional.mse_loss", "5.1.21.11. unit_scaling.functional.residual_add", "5.1.21.12. unit_scaling.functional.residual_apply", "5.1.21.13. unit_scaling.functional.residual_split", "5.1.21.14. unit_scaling.functional.rms_norm", "5.1.21.15. unit_scaling.functional.scaled_dot_product_attention", "5.1.21.16. unit_scaling.functional.silu", "5.1.21.17. unit_scaling.functional.silu_glu", "5.1.21.18. unit_scaling.functional.softmax", "5.1.22. unit_scaling.optim", "5.1.22.5. unit_scaling.optim.Adam", "5.1.22.6. unit_scaling.optim.AdamW", "5.1.22.7. unit_scaling.optim.SGD", "5.1.22.1. unit_scaling.optim.lr_scale_for_depth", "5.1.22.2. unit_scaling.optim.lr_scale_func_adam", "5.1.22.3. unit_scaling.optim.lr_scale_func_sgd", "5.1.22.4. unit_scaling.optim.scaled_parameters", "5.1.23. unit_scaling.parameter", "5.1.23.3. unit_scaling.parameter.OrderedDict", "5.1.23.1. unit_scaling.parameter.Parameter", "5.1.23.4. unit_scaling.parameter.ParameterData", "5.1.23.5. unit_scaling.parameter.Protocol", "5.1.23.6. unit_scaling.parameter.Tensor", "5.1.23.2. unit_scaling.parameter.has_parameter_data", "5.5. unit_scaling.scale", "5.5.1. unit_scaling.scale.scale_bwd", "5.5.2. unit_scaling.scale.scale_fwd", "5.1.2. unit_scaling.transformer_residual_scaling_rule", "5.6. unit_scaling.transforms", "5.6.9. unit_scaling.transforms.Metrics", "5.6.1. unit_scaling.transforms.compile", "5.6.2. unit_scaling.transforms.prune_non_float_tensors", "5.6.3. unit_scaling.transforms.prune_same_scale_tensors", "5.6.4. unit_scaling.transforms.prune_selected_nodes", "5.6.5. unit_scaling.transforms.simulate_format", "5.6.6. unit_scaling.transforms.simulate_fp8", "5.6.7. unit_scaling.transforms.track_scales", "5.6.8. unit_scaling.transforms.unit_scale", "5.7. unit_scaling.transforms.utils", "5.7.1. unit_scaling.transforms.utils.apply_transform", "5.7.2. unit_scaling.transforms.utils.patch_to_expand_modules", "5.7.3. unit_scaling.transforms.utils.replace_node_with_function", "5.7.4. unit_scaling.transforms.utils.torch_nn_modules_to_user_modules", "5.8. unit_scaling.utils", "5.8.2. unit_scaling.utils.ScalePair", "5.8.3. unit_scaling.utils.ScaleTracker", "5.8.4. unit_scaling.utils.ScaleTrackingInterpreter", "5.8.1. unit_scaling.utils.analyse_module", "5.1.3. unit_scaling.visualiser", "Unit Scaling", "3. Limitations", "Almost-scaled dot-product attention", "1. User guide"], "terms": {"unit": [0, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 26, 29, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 71, 83, 85, 89, 90, 91, 92, 98, 102, 105, 106], "scale": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 39, 40, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 77, 82, 83, 85, 86, 87, 89, 90, 91, 92, 98, 102, 103, 105], "i": [0, 2, 4, 6, 7, 8, 9, 10, 11, 12, 15, 16, 17, 18, 19, 22, 23, 24, 25, 28, 33, 40, 47, 48, 49, 50, 52, 53, 54, 58, 60, 61, 62, 63, 65, 66, 67, 73, 74, 77, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 97, 100, 101, 103, 104, 105, 106], "implement": [0, 7, 10, 13, 14, 16, 18, 19, 20, 25, 36, 60, 65, 66, 67, 77, 85, 92, 94, 102, 103, 104, 107], "us": [0, 2, 4, 6, 8, 9, 10, 11, 12, 13, 14, 19, 20, 22, 23, 25, 26, 28, 30, 42, 48, 49, 54, 56, 57, 58, 60, 63, 65, 66, 67, 76, 77, 83, 84, 85, 87, 89, 90, 91, 92, 94, 95, 97, 100, 101, 102, 103, 104, 105, 106, 107], "thin": 0, "wrapper": [0, 64, 92, 94], "around": [0, 77, 92, 94, 105], "exist": [0, 73, 77, 92, 107], "torch": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 17, 18, 23, 24, 25, 45, 46, 47, 49, 54, 60, 63, 65, 66, 67, 69, 70, 71, 72, 74, 75, 77, 83, 85, 89, 90, 91, 92, 94, 95, 97, 100, 102, 103, 105, 107], "nn": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 17, 18, 19, 23, 24, 25, 45, 49, 60, 65, 66, 67, 72, 74, 75, 77, 85, 89, 90, 92, 94, 95, 97, 100, 101, 102, 103, 107], "class": [0, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 41, 42, 47, 49, 60, 63, 64, 65, 66, 67, 72, 73, 75, 76, 77, 83, 84, 92, 97, 98, 99, 100, 101, 102, 107], "function": [0, 3, 4, 6, 7, 9, 11, 12, 13, 14, 17, 18, 19, 20, 21, 23, 24, 26, 28, 35, 41, 64, 65, 66, 67, 72, 77, 79, 82, 83, 85, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 100, 101, 102, 104, 105, 107], "document": [0, 8, 18, 107], "also": [0, 4, 15, 17, 57, 61, 65, 66, 74, 75, 77, 91, 92, 100, 106, 107], "inherit": 0, "from": [0, 5, 7, 8, 10, 11, 12, 18, 22, 24, 25, 48, 49, 65, 66, 67, 73, 77, 85, 86, 87, 89, 90, 91, 92, 94, 100, 101, 102, 103, 107], "standard": [0, 10, 18, 25, 41, 77, 85, 90, 92, 94, 99, 100, 101, 102, 103, 107], "pytorch": [0, 21, 60, 77, 85, 92, 107], "doc": [0, 2, 105], "modif": [0, 5, 6, 77, 100], "note": [0, 2, 4, 5, 6, 8, 10, 11, 12, 16, 19, 47, 49, 54, 60, 65, 66, 67, 77, 85, 91, 92, 94, 100, 107], "some": [0, 4, 47, 52, 53, 54, 60, 65, 66, 67, 77, 92, 104, 106, 107], "mai": [0, 4, 25, 52, 53, 54, 58, 60, 65, 66, 67, 77, 85, 87, 91, 92, 100, 103, 104, 105, 107], "longer": [0, 77], "relev": [0, 10, 24, 25, 71, 103], "ar": [0, 1, 4, 5, 6, 7, 10, 11, 12, 16, 18, 24, 25, 40, 47, 49, 54, 58, 60, 65, 66, 67, 73, 76, 77, 82, 84, 87, 89, 90, 91, 92, 94, 100, 101, 103, 106, 107], "nevertheless": 0, "The": [0, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 23, 24, 28, 39, 40, 42, 46, 47, 49, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 74, 77, 82, 85, 86, 87, 89, 90, 91, 92, 94, 100, 101, 104, 106, 107], "built": [0, 102], "mirror": [0, 77], "close": 0, "possibl": [0, 49, 77, 107], "can": [0, 4, 5, 6, 7, 8, 24, 58, 60, 65, 66, 67, 76, 77, 85, 87, 91, 92, 100, 101, 104, 105, 106, 107], "easili": 0, "swap": [0, 77, 107], "out": [0, 7, 11, 12, 46, 52, 53, 54, 56, 64, 77, 100, 104, 106, 107], "equival": [0, 4, 8, 10, 60, 77, 89, 90, 92, 100, 102, 107], "For": [0, 2, 8, 10, 25, 37, 40, 54, 60, 65, 66, 67, 71, 77, 82, 91, 100, 103, 104, 106, 107], "code": [0, 6, 25, 77, 102, 103, 107], "which": [0, 4, 6, 10, 14, 18, 25, 58, 60, 63, 65, 66, 67, 71, 77, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 100, 101, 103, 106, 107], "follow": [0, 2, 4, 25, 39, 54, 60, 65, 66, 67, 73, 77, 85, 91, 95, 100, 103, 104, 106, 107], "import": [0, 24, 67, 85, 86, 87, 91, 92, 106, 107], "f": [0, 39, 47, 49, 60, 65, 66, 67, 73, 77, 107], "appli": [0, 4, 6, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 28, 39, 46, 47, 48, 50, 51, 52, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 64, 71, 77, 80, 81, 85, 86, 87, 92, 94, 95, 100, 106, 107], "first": [0, 2, 6, 8, 39, 54, 65, 66, 77, 85, 94, 100, 101, 107], "ad": [0, 6, 10, 16, 60, 65, 66, 67, 77, 91, 105, 106, 107], "unit_sc": [0, 104, 105, 107], "uu": [0, 15, 71, 74, 107], "u": [0, 11, 12, 15, 40, 64, 65, 66, 67, 71, 72, 74, 75, 82, 92, 102, 107], "replac": [0, 77, 92, 96, 107], "letter": 0, "those": [0, 4, 47, 77, 89, 90, 104, 107], "assum": [0, 31, 32, 33, 34, 106, 107], "thei": [0, 1, 6, 24, 25, 65, 66, 67, 76, 77, 92, 100, 103, 104, 106, 107], "support": [0, 2, 4, 7, 8, 11, 12, 15, 25, 46, 47, 48, 49, 52, 53, 54, 60, 65, 66, 67, 74, 75, 77, 78, 85, 92, 94, 100, 103], "click": 0, "below": [0, 10, 18, 47, 107], "full": [0, 22, 65, 66, 67, 77, 91, 100, 106, 107], "transform": [1, 6, 10, 11, 12, 19, 20, 23, 24, 25, 39, 40, 52, 53, 64, 65, 66, 67, 77, 82, 103, 104, 105, 106, 107], "seem": [1, 106], "all": [1, 5, 8, 11, 12, 23, 24, 25, 40, 49, 60, 63, 65, 66, 67, 73, 77, 82, 86, 87, 88, 92, 100, 103, 106, 107], "you": [1, 4, 52, 53, 54, 60, 65, 66, 67, 77, 100, 101, 105, 106, 107], "need": [1, 4, 58, 77, 85, 91, 100, 102, 106, 107], "we": [1, 25, 65, 66, 67, 77, 87, 90, 92, 94, 100, 101, 103, 104, 105, 106, 107], "don": [1, 24, 25, 65, 66, 67, 86, 87, 103, 106, 107], "t": [1, 8, 24, 25, 44, 49, 52, 53, 65, 66, 67, 76, 77, 85, 86, 87, 92, 94, 95, 97, 100, 103, 105, 106, 107], "fulli": [1, 101, 106], "understand": [1, 106, 107], "why": [1, 60, 106, 107], "work": [1, 25, 60, 77, 85, 93, 102, 103, 104, 105, 107], "so": [1, 18, 60, 63, 65, 66, 67, 77, 85, 91, 92, 94, 100, 101, 106, 107], "well": [1, 4, 65, 66, 67, 77, 105, 106, 107], "while": [1, 4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 100, 106], "notic": [1, 52, 53, 54, 106], "someth": [1, 65, 66, 67, 106, 107], "surpris": [1, 106], "about": [1, 77, 84, 106], "heart": [1, 106], "architectur": [1, 4, 47, 106], "how": [1, 77, 100, 104, 106], "output": [1, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 24, 25, 31, 32, 33, 34, 39, 46, 47, 49, 50, 52, 53, 54, 60, 61, 62, 63, 74, 77, 86, 100, 101, 102, 103, 106, 107], "dougla": [1, 106], "orr": [1, 106], "octob": [1, 106], "2023": [1, 104, 106, 107], "user": [2, 25, 60, 65, 66, 67, 77, 89, 90, 91, 92, 94, 97, 100, 103, 104], "who": [2, 104, 105, 107], "wish": [2, 94, 104, 107], "thi": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 25, 28, 39, 40, 42, 46, 48, 49, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 71, 73, 75, 77, 82, 85, 86, 87, 89, 90, 91, 92, 94, 95, 97, 100, 101, 103, 104, 105, 106, 107], "codebas": [2, 104, 107], "setup": 2, "requir": [2, 4, 60, 77, 91, 92, 100, 107], "time": [2, 10, 46, 54, 65, 66, 67, 77, 107], "python3": [2, 102], "m": [2, 7, 9, 11, 12, 17, 18, 54, 85, 89, 90, 91, 92, 94], "venv": 2, "sourc": [2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 76, 77, 78, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 97, 99, 100, 101, 102, 103], "bin": [2, 77], "activ": [2, 10, 17, 61, 89, 106, 107], "pip": [2, 104, 107], "instal": 2, "r": [2, 8, 49, 77, 100], "dev": 2, "txt": 2, "Or": 2, "ipu": [2, 77], "subsequ": [2, 6], "run": [2, 6, 23, 24, 60, 65, 66, 67, 77, 89, 90, 101, 104, 107], "pre": [2, 65, 66, 67, 101, 107], "flight": 2, "check": [2, 71, 76, 77, 78, 100, 106, 107], "help": [2, 104, 105, 106, 107], "see": [2, 4, 8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 60, 61, 63, 76, 77, 92, 100, 101, 104, 105, 106, 107], "command": 2, "id": [2, 22, 65, 66, 67], "recommend": [2, 25, 29, 77, 103, 104, 107], "python": [2, 5, 77, 102], "intepret": 2, "set": [2, 4, 7, 10, 11, 12, 16, 40, 47, 48, 60, 65, 66, 67, 71, 73, 77, 82, 84, 88, 91, 92, 100, 107], "format": [2, 77, 89, 90, 104, 107], "save": [2, 65, 66, 67, 77, 100], "enabl": [2, 60, 67, 71, 77, 79, 92, 94, 100, 107], "consid": [2, 4, 77, 106, 107], "env": [2, 100], "file": 2, "pythonpath": 2, "exampl": [2, 4, 6, 7, 8, 9, 10, 11, 12, 17, 18, 19, 37, 46, 47, 49, 54, 60, 65, 66, 67, 71, 76, 77, 100, 101, 102, 105, 106, 107], "echo": 2, "pwd": 2, "differ": [2, 6, 11, 12, 25, 49, 54, 58, 60, 65, 66, 67, 77, 79, 85, 94, 103, 106, 107], "path": [2, 22, 25, 103], "devcontain": 2, "cd": 2, "make": [2, 24, 60, 77, 92, 94, 97, 106, 107], "html": 2, "view": [2, 24, 73, 77], "_build": 2, "index": [2, 4, 5, 8, 40, 49, 77, 82, 100], "your": [2, 77, 100, 107], "browser": 2, "version": [3, 25, 39, 45, 54, 65, 66, 67, 77, 85, 91, 92, 94, 95, 97, 103, 107], "common": [3, 26, 45, 46, 65, 66, 67, 105, 107], "modul": [3, 5, 6, 7, 8, 10, 11, 12, 16, 19, 23, 24, 25, 28, 35, 49, 54, 60, 77, 83, 85, 86, 87, 89, 90, 91, 92, 94, 95, 97, 101, 102, 103, 106, 107], "mult": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "float": [4, 7, 8, 9, 10, 13, 16, 17, 18, 19, 20, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 38, 39, 40, 42, 46, 47, 48, 49, 50, 51, 52, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 68, 69, 70, 71, 77, 80, 81, 82, 84, 86, 87, 91, 99, 101, 103, 107], "1": [4, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 37, 40, 46, 47, 49, 50, 54, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 77, 82, 87, 100, 102, 104, 107], "0": [4, 7, 8, 9, 10, 13, 17, 18, 19, 20, 37, 38, 40, 42, 46, 47, 48, 49, 50, 52, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 76, 77, 82, 100, 102, 106, 107], "weight": [4, 8, 10, 11, 12, 15, 16, 17, 19, 20, 37, 47, 49, 51, 52, 53, 56, 57, 58, 59, 60, 61, 65, 66, 67, 71, 74, 77, 89, 92, 102, 106, 107], "tensor": [4, 8, 11, 12, 15, 16, 18, 22, 24, 25, 38, 39, 42, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 71, 74, 80, 81, 84, 86, 87, 91, 94, 99, 100, 101, 102, 103, 106, 107], "none": [4, 5, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 24, 28, 38, 39, 46, 47, 49, 50, 51, 52, 53, 54, 55, 59, 60, 61, 63, 65, 66, 67, 71, 73, 74, 77, 96, 97, 99, 100, 101, 102, 107], "size_averag": [4, 47, 55], "bool": [4, 7, 8, 10, 11, 12, 13, 16, 17, 20, 24, 25, 38, 47, 48, 49, 55, 60, 61, 65, 66, 67, 71, 77, 96, 100, 101, 102, 103], "ignore_index": [4, 47], "int": [4, 5, 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 25, 38, 40, 42, 43, 44, 46, 47, 49, 51, 59, 63, 65, 66, 67, 74, 76, 77, 82, 84, 100, 101, 103, 107], "100": [4, 37, 47, 77], "reduc": [4, 47, 55, 77, 89, 90], "reduct": [4, 47, 55, 77], "str": [4, 9, 11, 12, 13, 14, 17, 18, 19, 20, 22, 24, 25, 28, 39, 42, 46, 47, 50, 52, 53, 54, 55, 61, 63, 65, 66, 67, 71, 77, 88, 101, 102, 103], "mean": [4, 7, 9, 10, 11, 12, 16, 17, 18, 24, 27, 29, 30, 38, 47, 52, 53, 55, 77, 85, 107], "label_smooth": [4, 47], "comput": [4, 7, 8, 10, 18, 27, 29, 30, 38, 40, 47, 49, 55, 60, 63, 65, 66, 77, 82, 100, 107], "cross": [4, 47, 77, 107], "entropi": [4, 47, 107], "loss": [4, 24, 25, 47, 65, 66, 67, 86, 87, 91, 103, 107], "between": [4, 6, 24, 37, 40, 47, 58, 65, 66, 67, 77, 82, 107], "input": [4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 23, 24, 25, 31, 32, 33, 34, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 57, 58, 59, 60, 61, 62, 63, 65, 66, 67, 77, 80, 81, 87, 90, 91, 92, 100, 101, 102, 103, 106, 107], "logit": [4, 47, 77], "target": [4, 47, 55, 67, 77, 88, 96, 101, 107], "It": [4, 6, 25, 63, 65, 66, 67, 77, 100, 103, 106, 107], "when": [4, 6, 9, 10, 11, 12, 16, 18, 47, 50, 54, 60, 65, 66, 67, 77, 89, 90, 91, 95, 97, 101, 106, 107], "train": [4, 7, 8, 10, 48, 49, 58, 60, 65, 66, 67, 104, 105, 106, 107], "classif": 4, "problem": [4, 104, 107], "c": [4, 10, 46, 47, 60, 76, 77, 100], "If": [4, 7, 8, 10, 11, 12, 40, 47, 48, 49, 52, 53, 54, 60, 63, 65, 66, 67, 73, 77, 82, 89, 90, 100, 101, 106], "provid": [4, 6, 25, 27, 29, 30, 31, 32, 33, 34, 60, 64, 65, 66, 67, 73, 77, 87, 91, 92, 103, 105, 107], "option": [4, 8, 9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 22, 24, 25, 28, 39, 40, 46, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 63, 65, 66, 67, 71, 77, 82, 84, 87, 92, 94, 96, 100, 101, 102, 103], "argument": [4, 9, 50, 54, 60, 65, 66, 67, 77, 100, 101, 107], "should": [4, 18, 39, 47, 56, 58, 60, 65, 66, 67, 75, 77, 89, 90, 91, 92, 94, 95, 100, 101, 102, 107], "1d": 4, "assign": [4, 77], "each": [4, 6, 7, 8, 10, 11, 12, 24, 25, 47, 49, 52, 60, 65, 66, 67, 77, 89, 90, 91, 100, 101, 102, 103, 107], "particularli": [4, 107], "have": [4, 8, 23, 24, 49, 52, 53, 54, 58, 65, 66, 67, 77, 86, 87, 89, 90, 91, 92, 100, 104, 105, 106, 107], "an": [4, 5, 6, 7, 8, 10, 11, 12, 14, 15, 16, 18, 22, 23, 24, 25, 33, 39, 40, 48, 49, 54, 58, 60, 65, 66, 67, 71, 73, 74, 77, 82, 86, 87, 88, 89, 91, 92, 99, 100, 101, 103, 106, 107], "unbalanc": 4, "expect": [4, 10, 16, 49, 77, 89, 90, 107], "contain": [4, 5, 6, 8, 23, 47, 49, 65, 66, 67, 77, 84, 86, 89, 90, 91, 96, 99, 100, 106, 107], "unnorm": [4, 47], "do": [4, 7, 8, 48, 49, 65, 66, 67, 77, 100, 101, 107], "posit": [4, 19, 77, 101], "sum": [4, 16, 18, 47, 63, 77, 106], "gener": [4, 22, 23, 24, 25, 42, 65, 66, 67, 76, 77, 86, 87, 92, 102, 103, 107], "ha": [4, 7, 10, 16, 28, 47, 54, 60, 65, 66, 67, 73, 77, 91, 92, 100, 106, 107], "size": [4, 8, 10, 11, 12, 13, 14, 19, 20, 22, 25, 47, 49, 60, 77, 100, 103], "unbatch": 4, "minibatch": [4, 47], "d_1": [4, 47], "d_2": [4, 47], "d_k": [4, 47], "k": [4, 11, 12, 47, 54, 73, 77, 106], "geq": [4, 47], "dimension": [4, 8, 10, 18, 47, 54, 77], "case": [4, 9, 11, 12, 13, 14, 17, 18, 19, 20, 25, 39, 46, 47, 50, 52, 53, 54, 61, 63, 65, 66, 67, 73, 77, 87, 92, 95, 97, 103, 106, 107], "last": [4, 6, 10, 11, 12, 51, 59, 73, 77, 85], "being": [4, 8, 47, 58, 65, 66, 67, 77, 85, 94, 100, 101], "higher": [4, 60, 77], "dimens": [4, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 51, 52, 53, 54, 59, 63, 77, 100], "per": [4, 10, 16, 47, 65, 66, 67, 71, 77, 100], "pixel": 4, "2d": [4, 77], "imag": [4, 10, 77], "criterion": 4, "either": [4, 47, 73, 77, 100], "indic": [4, 8, 24, 25, 47, 49, 60, 65, 66, 67, 77, 100, 103], "rang": [4, 18, 25, 63, 77, 85, 89, 90, 92, 101, 103, 107], "where": [4, 8, 9, 10, 11, 12, 17, 18, 47, 49, 50, 52, 53, 54, 60, 61, 62, 65, 66, 67, 77, 89, 91, 92, 101, 107], "number": [4, 9, 11, 12, 13, 17, 18, 19, 20, 22, 33, 41, 42, 46, 47, 49, 51, 52, 53, 59, 77, 107], "specifi": [4, 8, 47, 49, 60, 63, 65, 66, 67, 73, 77, 87, 91, 100, 107], "accept": [4, 6, 77, 100], "necessarili": [4, 77], "unreduc": 4, "e": [4, 8, 10, 49, 54, 60, 73, 77, 85, 86, 87, 91, 100], "describ": [4, 7, 10, 16, 40, 47, 77, 82], "ell": 4, "x": [4, 9, 10, 16, 17, 24, 38, 42, 50, 61, 62, 76, 77, 92, 100, 102, 107], "y": [4, 10, 16, 52, 53, 77, 100], "l": [4, 60], "l_1": 4, "dot": [4, 54, 60, 77, 104], "l_n": 4, "top": [4, 92], "quad": 4, "w_": 4, "y_n": 4, "log": [4, 77, 106, 107], "frac": [4, 7, 10, 11, 12, 16, 18, 60, 63, 77], "exp": [4, 18, 63, 77], "x_": [4, 18, 63], "n": [4, 8, 10, 18, 47, 54, 60, 77, 101, 106], "sum_": 4, "cdot": [4, 106], "mathbb": 4, "text": [4, 8, 9, 10, 11, 12, 17, 18, 22, 46, 47, 50, 61, 62, 63, 65, 66, 67, 77], "ignor": [4, 47, 76, 77, 107], "_index": 4, "w": [4, 8, 10, 49, 77, 100], "span": [4, 77], "default": [4, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 24, 25, 39, 40, 46, 47, 48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 63, 65, 66, 67, 71, 73, 77, 82, 87, 92, 94, 95, 96, 97, 100, 102, 103, 107], "begin": [4, 47, 65, 66, 67, 73, 77, 107], "end": [4, 5, 6, 19, 47, 65, 66, 67, 73, 77, 107], "logsoftmax": 4, "nllloss": 4, "probabl": [4, 7, 13, 19, 20, 47, 48, 60, 77], "label": [4, 22, 25, 103], "beyond": [4, 65, 66, 107], "singl": [4, 6, 10, 57, 65, 66, 67, 77, 85, 106], "item": [4, 73, 77], "blend": 4, "smooth": [4, 47], "etc": 4, "w_c": 4, "y_": 4, "perform": [4, 6, 60, 63, 65, 66, 67, 71, 77, 100, 107], "better": [4, 60, 92, 106], "allow": [4, 6, 58, 77, 85, 106, 107], "optim": [4, 60, 85, 94, 95, 104], "onli": [4, 10, 25, 31, 32, 33, 34, 42, 47, 54, 60, 68, 76, 77, 80, 81, 85, 86, 89, 90, 91, 94, 100, 101, 103, 105, 107], "too": [4, 60, 77, 89, 90], "restrict": [4, 54], "paramet": [4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 18, 19, 20, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 46, 47, 48, 49, 50, 52, 53, 54, 56, 57, 58, 60, 61, 62, 63, 64, 65, 66, 67, 71, 80, 81, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 101, 102, 103, 104], "manual": [4, 6, 47, 65, 66, 67, 77, 85, 92, 107], "rescal": [4, 18, 47], "given": [4, 5, 6, 8, 19, 22, 40, 42, 44, 47, 49, 60, 64, 73, 76, 77, 82, 86, 87, 88, 89, 90, 96, 100, 102, 107], "point": [4, 24, 25, 42, 49, 60, 77, 86, 87, 91, 103, 107], "dtype": [4, 8, 10, 11, 12, 47, 52, 53, 54, 60, 63, 77, 100, 101], "deprec": [4, 47, 77], "By": [4, 47, 77, 92, 95, 97, 107], "averag": [4, 40, 47, 65, 66, 82], "over": [4, 6, 10, 16, 47, 56, 57, 58, 60, 65, 66, 67, 77, 92, 100, 106, 107], "element": [4, 7, 10, 16, 18, 39, 47, 48, 50, 55, 60, 63, 73, 77, 106], "batch": [4, 8, 10, 22, 25, 47, 49, 54, 65, 66, 67, 77, 103], "multipl": [4, 47, 54, 77, 80, 81, 87, 107], "sampl": [4, 7, 8, 11, 12, 22, 47, 48, 49, 77], "field": [4, 47, 75, 77], "fals": [4, 7, 8, 10, 11, 12, 16, 17, 24, 38, 47, 48, 49, 60, 61, 65, 66, 67, 71, 73, 77, 100, 102], "instead": [4, 47, 65, 66, 67, 77], "true": [4, 7, 8, 10, 11, 12, 16, 24, 25, 47, 48, 49, 60, 65, 66, 67, 71, 73, 77, 89, 90, 96, 100, 101, 102, 103], "valu": [4, 6, 8, 10, 11, 12, 16, 18, 24, 25, 37, 42, 47, 56, 57, 58, 60, 65, 66, 67, 71, 73, 77, 86, 89, 90, 92, 100, 101, 103, 106, 107], "doe": [4, 5, 6, 8, 47, 54, 60, 65, 66, 67, 71, 73, 77, 85, 94, 107], "contribut": [4, 8, 40, 47, 49, 56, 57, 58, 82], "gradient": [4, 8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 31, 32, 33, 34, 39, 46, 47, 49, 50, 52, 53, 54, 58, 61, 63, 65, 66, 67, 77, 84, 89, 92, 100, 106, 107], "non": [4, 8, 13, 20, 41, 42, 47, 49, 54, 58, 60, 71, 77, 87, 100, 107], "applic": [4, 47, 77, 100], "observ": [4, 47, 77], "depend": [4, 24, 47, 54, 60, 71, 77, 92, 101, 106, 107], "return": [4, 6, 15, 18, 22, 23, 24, 25, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 44, 47, 54, 56, 58, 60, 62, 63, 65, 66, 67, 71, 73, 74, 76, 77, 80, 81, 82, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 100, 101, 102, 103, 107], "taken": [4, 107], "process": [4, 8, 47, 65, 66, 67, 86, 92, 101], "meantim": [4, 47], "two": [4, 31, 47, 54, 58, 60, 65, 66, 67, 77, 85, 100, 107], "arg": [4, 6, 25, 39, 47, 65, 66, 67, 75, 77, 96, 100, 101, 103], "overrid": [4, 47, 100], "A": [4, 5, 6, 7, 8, 13, 14, 18, 19, 20, 47, 49, 54, 60, 63, 65, 66, 67, 77, 84, 85, 104, 106, 107], "amount": [4, 47], "becom": [4, 6, 47, 77], "mixtur": [4, 47], "origin": [4, 17, 47, 61, 77, 81, 101, 106], "ground": [4, 47], "truth": [4, 47], "uniform": [4, 47, 77], "distribut": [4, 7, 9, 47, 48, 50, 77, 105, 106, 107], "rethink": [4, 47], "incept": [4, 47], "vision": [4, 47], "multipli": [4, 9, 13, 17, 18, 47, 50, 54, 60, 61, 62, 63, 77, 106, 107], "chang": [4, 5, 6, 8, 9, 13, 17, 18, 24, 47, 50, 60, 61, 62, 63, 65, 66, 67, 77, 85, 87, 106, 107], "shape": [4, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 46, 47, 49, 50, 52, 53, 60, 61, 62, 63, 77, 100], "nonlinear": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 106], "typic": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 95, 97, 107], "high": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77], "correspond": [4, 8, 9, 13, 17, 18, 25, 28, 44, 47, 49, 50, 60, 61, 62, 63, 65, 66, 67, 73, 77, 92, 100, 103], "sharper": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "low": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 77, 104, 107], "temperatur": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63, 106], "flatter": [4, 9, 13, 17, 18, 47, 50, 60, 61, 62, 63], "same": [4, 6, 7, 9, 10, 11, 12, 17, 18, 40, 47, 54, 60, 62, 77, 82, 87, 89, 90, 100, 106, 107], "otherwis": [4, 65, 66, 67, 73, 77, 100], "scalar": [4, 10, 54, 77, 80, 81, 99, 107], "align": [4, 47, 60, 65, 66, 67, 77], "randn": [4, 7, 9, 10, 11, 12, 17, 18, 46, 47, 77, 102, 107], "3": [4, 8, 9, 10, 18, 46, 47, 49, 50, 65, 66, 67, 77, 107], "5": [4, 6, 7, 8, 9, 10, 16, 37, 46, 47, 48, 49, 50, 52, 60, 65, 66, 67, 77, 104, 107], "requires_grad": [4, 8, 47, 77, 100], "empti": [4, 77], "long": [4, 77], "random_": [4, 77], "backward": [4, 11, 12, 23, 24, 25, 42, 47, 54, 58, 65, 66, 67, 77, 79, 80, 84, 86, 87, 89, 90, 91, 92, 99, 100, 101, 102, 103, 106, 107], "softmax": [4, 13, 19, 20, 47, 60, 77, 104, 106, 107], "dim": [4, 18, 38, 47, 60, 63, 77], "iter": [5, 65, 66, 67, 71, 73, 77, 88, 94], "modulelist": [5, 6], "automat": [5, 6, 24, 25, 60, 77, 91, 92, 95, 100, 102, 103, 107], "configur": [5, 6], "depth": [5, 6, 40, 68, 82, 107], "sake": [5, 6, 83, 91], "track": [5, 6, 24, 25, 77, 86, 91, 103], "caus": [5, 6, 77, 107], "after": [5, 6, 28, 54, 65, 66, 67, 77, 92, 100], "initi": [5, 6, 8, 10, 11, 12, 16, 40, 49, 65, 66, 67, 77, 82, 107], "construct": [5, 6, 8, 15, 49, 74, 77], "like": [5, 6, 13, 19, 20, 65, 66, 67, 73, 77, 100, 107], "regular": [5, 7, 66, 71, 92], "list": [5, 6, 8, 10, 49, 65, 66, 67, 77, 94, 101, 105], "properli": [5, 106], "regist": [5, 6, 65, 66, 67, 77], "visibl": 5, "method": [5, 6, 25, 65, 71, 73, 77, 85, 86, 87, 89, 90, 91, 94, 100, 101, 102, 103, 104, 107], "add": [5, 56, 65, 66, 67, 77, 92, 104, 105, 107], "append": [5, 6, 19, 54, 77, 107], "extend": [5, 72, 100], "self": [5, 13, 17, 20, 60, 61, 65, 66, 67, 76, 77, 101, 102, 104, 107], "insert": [5, 73, 89, 90], "befor": [5, 63, 65, 66, 67, 77, 89, 90, 100, 101, 106, 107], "ani": [6, 7, 8, 9, 10, 11, 12, 17, 18, 24, 25, 33, 52, 53, 60, 64, 65, 66, 67, 71, 77, 88, 91, 92, 94, 95, 96, 100, 101, 102, 103, 104, 105, 107], "sequenti": [6, 19], "order": [6, 58, 60, 65, 66, 67, 73, 77, 101, 107], "pass": [6, 8, 23, 24, 25, 42, 58, 60, 65, 66, 67, 71, 76, 77, 79, 80, 81, 84, 89, 90, 91, 92, 95, 99, 100, 101, 102, 103, 106, 107], "constructor": 6, "altern": [6, 40, 64, 82, 107], "ordereddict": 6, "forward": [6, 7, 23, 24, 25, 42, 60, 77, 79, 80, 81, 84, 85, 89, 90, 91, 92, 94, 95, 99, 100, 101, 102, 103, 106, 107], "chain": [6, 77, 91], "final": [6, 12, 53, 54, 91, 92], "call": [6, 7, 23, 24, 25, 60, 65, 66, 67, 77, 85, 86, 87, 89, 90, 91, 92, 94, 95, 96, 97, 100, 101, 102, 103], "sequenc": [6, 13, 20, 22, 25, 51, 77, 103, 106], "treat": [6, 10, 18, 77], "whole": [6, 106], "store": [6, 8, 77, 101], "submodul": 6, "what": [6, 60, 65, 66, 67, 86, 100, 104, 106], "": [6, 14, 31, 32, 33, 34, 52, 53, 54, 60, 65, 66, 67, 73, 77, 91, 100, 101, 102, 106, 107], "exactli": [6, 107], "sound": 6, "On": [6, 11, 12, 54, 65, 66, 67], "other": [6, 21, 46, 54, 60, 65, 66, 67, 77, 91, 105, 107], "hand": 6, "layer": [6, 10, 11, 12, 13, 14, 16, 19, 20, 40, 51, 58, 65, 66, 67, 82, 85, 92, 107], "connect": [6, 20, 56, 57, 58, 92, 106, 107], "cascad": 6, "wai": [6, 77, 85, 92, 100, 107], "creat": [6, 8, 71, 73, 77], "small": [6, 106, 107], "model": [6, 21, 24, 25, 58, 65, 66, 67, 71, 86, 87, 89, 90, 91, 92, 98, 103, 104, 106], "conv2d": 6, "20": [6, 7, 10, 11, 12, 46, 77], "relu": [6, 102], "64": [6, 60, 107], "second": [6, 8, 54, 65, 66, 100], "abov": [6, 60, 77, 106, 107], "conv1": 6, "relu1": 6, "conv2": 6, "relu2": 6, "p": [7, 8, 15, 48, 49, 54, 60, 74, 77], "inplac": [7, 17, 48, 61, 65, 66, 67, 100, 102], "zero": [7, 8, 10, 15, 48, 49, 60, 65, 66, 67, 74, 77, 92, 100], "chosen": [7, 31, 32, 33, 34, 60], "independ": [7, 71, 77, 92], "bernoulli": [7, 48, 77], "channel": [7, 10, 77], "everi": [7, 18, 40, 77, 82, 100, 101], "proven": 7, "effect": [7, 77, 89, 90, 106, 107], "techniqu": [7, 107], "prevent": [7, 63, 77, 100], "co": [7, 77], "adapt": [7, 107], "neuron": 7, "paper": [7, 10, 16, 40, 65, 66, 82, 92, 104, 106, 107], "improv": [7, 60, 65, 66, 67], "neural": [7, 17, 61], "network": [7, 12, 17, 53, 61, 65, 66, 67], "featur": [7, 52, 53, 54, 104, 105, 107], "detector": 7, "_": [7, 60, 65, 66, 67, 100], "furthermor": [7, 77], "factor": [7, 14, 31, 32, 33, 34, 58, 60, 67, 68, 69, 70, 71, 79, 80, 81, 92, 106, 107], "dure": [7, 8, 24, 25, 49, 60, 77, 100, 103, 107], "evalu": [7, 10, 60, 101, 107], "simpli": [7, 107], "ident": [7, 107], "oper": [7, 10, 16, 24, 25, 26, 48, 52, 53, 54, 60, 63, 77, 79, 87, 92, 100, 102, 103, 105, 106, 107], "place": [7, 48, 49, 77, 100, 107], "2": [7, 8, 9, 10, 16, 17, 18, 38, 46, 49, 50, 52, 53, 54, 60, 65, 66, 67, 77, 87, 100, 102, 107], "16": [7, 77, 87, 107], "num_embed": 8, "embedding_dim": [8, 10, 49], "padding_idx": [8, 49], "max_norm": [8, 49], "norm_typ": [8, 49], "scale_grad_by_freq": [8, 49], "spars": [8, 18, 49, 52, 53, 54, 77], "_weight": 8, "_freez": 8, "devic": [8, 10, 11, 12, 52, 53, 54, 60, 77, 101], "lookup": [8, 49], "tabl": [8, 49], "look": [8, 49, 54, 65, 66, 67, 77, 101, 106, 107], "up": [8, 49, 77, 100, 101, 106, 107], "fix": [8, 49, 77, 106, 107], "dictionari": [8, 65, 66, 67, 73, 77, 92], "often": [8, 49, 77, 107], "word": [8, 49], "retriev": [8, 28, 49, 77, 100, 101], "them": [8, 18, 63, 77, 92, 100, 101, 105], "vector": [8, 49, 54, 77], "entri": [8, 49, 65, 66, 67, 77], "therefor": [8, 49, 77], "updat": [8, 49, 71, 73, 77, 107], "remain": [8, 49], "pad": [8, 13, 20, 49, 77], "newli": 8, "anoth": [8, 77], "norm": [8, 11, 12, 15, 49, 74, 77], "larger": [8, 49, 56, 57, 58, 77, 92, 107], "than": [8, 49, 58, 60, 65, 66, 67, 77, 85, 91, 92, 101], "renorm": [8, 49, 77], "invers": [8, 49, 77, 106], "frequenc": [8, 49], "mini": [8, 10, 49], "matrix": [8, 49, 54, 60, 77], "more": [8, 49, 60, 61, 63, 65, 66, 67, 77, 91, 100, 105, 106, 107], "detail": [8, 9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 65, 66, 76, 77, 92, 100, 101, 105, 107], "regard": [8, 49, 65, 66, 107], "learnabl": [8, 10, 11, 12, 16], "mathcal": [8, 11, 12], "type": [8, 11, 12, 19, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 37, 39, 40, 46, 49, 56, 58, 60, 62, 63, 65, 66, 67, 71, 76, 77, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 94, 95, 96, 101, 102, 103], "inttensor": [8, 77], "longtensor": [8, 49, 77], "arbitrari": [8, 23, 24, 25, 49, 60, 77, 92, 99, 103, 107], "extract": [8, 49, 101], "h": [8, 10, 77], "_dim": 8, "10": [8, 10, 15, 37, 46, 49, 74, 77, 102, 107], "4": [8, 14, 46, 49, 77, 100, 102, 107], "9": [8, 49, 65, 66, 67, 77], "xdoctest": [8, 49, 67, 100], "ignore_w": [8, 49], "determinist": [8, 49, 60, 77], "0251": 8, "6902": 8, "7172": 8, "6431": 8, "0748": 8, "6969": 8, "4970": 8, "3448": 8, "9685": 8, "3677": 8, "7265": 8, "1685": 8, "4362": 8, "4004": 8, "9400": 8, "9124": 8, "3616": 8, "1151": 8, "0000": [8, 49, 77], "1535": 8, "0309": 8, "9315": 8, "1655": 8, "9897": 8, "0635": 8, "7895": 8, "7089": 8, "0364": 8, "6778": 8, "5803": 8, "2678": 8, "no_grad": [8, 65, 66, 67, 77], "ones": [8, 10, 16, 49, 60, 77], "classmethod": 8, "from_pretrain": 8, "freez": 8, "instanc": [8, 10, 65, 66, 67, 77, 91], "floattensor": [8, 77], "get": [8, 71, 73, 77, 105, 107], "learn": [8, 10, 11, 12, 17, 61, 65, 66, 67, 71, 106, 107], "pretrain": 8, "6": [8, 46, 77, 102], "1000": [8, 37], "3000": 8, "constraint": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 101, 104, 107], "to_output_scal": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "approxim": [9, 17, 50, 61, 107], "gaussian": [9, 17, 50, 61], "error": [9, 17, 24, 25, 50, 55, 60, 61, 77, 100, 101, 103], "linear": [9, 12, 17, 50, 53, 61, 62, 77, 89, 90, 102, 104, 107], "phi": [9, 50], "cumul": [9, 50], "tanh": [9, 50, 77], "estim": [9, 10, 50], "sqrt": [9, 10, 11, 12, 16, 38, 50, 60, 65, 66, 77], "pi": [9, 50, 77], "044715": [9, 50], "algorithm": [9, 60, 65, 66], "name": [9, 11, 12, 13, 14, 17, 18, 19, 20, 22, 25, 28, 39, 46, 50, 52, 53, 54, 61, 63, 77, 101, 103], "In": [9, 11, 12, 13, 14, 17, 18, 19, 20, 24, 39, 46, 50, 52, 53, 54, 58, 60, 61, 63, 73, 77, 92, 95, 100, 106, 107], "must": [9, 11, 12, 13, 14, 17, 18, 19, 20, 28, 37, 39, 46, 50, 52, 53, 54, 60, 61, 63, 77, 85, 86, 87, 100, 106, 107], "one": [9, 11, 12, 13, 14, 17, 18, 19, 20, 28, 39, 46, 50, 52, 53, 54, 58, 60, 61, 63, 65, 66, 67, 77, 85, 100, 107], "gmean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104, 107], "hmean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "amean": [9, 11, 12, 13, 14, 17, 18, 19, 20, 39, 46, 50, 52, 53, 54, 61, 63, 104], "to_grad_input_scal": [9, 11, 12, 14, 17, 18, 39, 50, 52, 53, 61, 63, 104], "normalized_shap": [10, 16, 51, 59], "ep": [10, 16, 38, 51, 59, 65, 66], "1e": [10, 16, 51, 59, 65, 66, 67, 77, 107], "05": [10, 16, 51, 59, 77, 87, 107], "elementwise_affin": [10, 16], "bia": [10, 11, 12, 15, 51, 52, 53, 60, 74, 102, 107], "normal": [10, 16, 42, 51, 59, 77, 106, 107], "mathrm": [10, 38, 65, 66, 106], "var": [10, 77], "epsilon": [10, 16, 38, 65, 66], "gamma": [10, 16, 65, 66, 67], "beta": [10, 52, 53, 54, 60, 65, 66, 77, 104, 107], "deviat": [10, 77, 99, 100, 101, 102, 107], "calcul": [10, 40, 68, 69, 70, 82], "d": [10, 52, 53, 73, 77, 100, 102, 107], "affin": [10, 77], "via": [10, 77, 86, 87, 91, 92, 101, 107], "bias": [10, 92], "unbias": 10, "unlik": [10, 24, 25, 77, 103, 107], "entir": [10, 77, 107], "plane": 10, "statist": 10, "data": [10, 11, 12, 15, 22, 52, 53, 63, 74, 77, 84, 101], "both": [10, 40, 54, 60, 77, 82, 100, 102, 107], "mode": [10, 60, 77, 100], "_shape": 10, "ldot": [10, 65, 66, 67, 77], "integ": [10, 46, 77, 100], "singleton": [10, 77], "specif": [10, 60, 65, 66, 67, 77, 89, 90, 101], "denomin": [10, 16, 65, 66], "numer": [10, 16, 24, 25, 49, 65, 66, 83, 89, 90, 103, 107], "stabil": [10, 16, 65, 66], "boolean": [10, 16, 60, 77, 100], "addit": [10, 11, 12, 18, 46, 52, 53, 65, 66, 67, 77, 91, 107], "nlp": 10, "sentence_length": 10, "embed": [10, 19, 40, 82, 104], "layer_norm": [10, 104], "three": [10, 32, 34, 60, 106], "spatial": 10, "shown": [10, 24, 107], "in_featur": [11, 12], "out_featur": [11, 12], "weight_mup_typ": [11, 12], "liter": [11, 12, 15, 74, 77], "incom": [11, 12, 52, 53], "tensorfloat32": [11, 12, 52, 53, 54], "certain": [11, 12, 51, 54, 59, 65, 66, 67, 87, 101], "rocm": [11, 12, 54], "float16": [11, 12, 54, 60, 65, 66, 67, 77], "precis": [11, 12, 54, 60, 104, 107], "_featur": [11, 12, 52, 53], "h_": [11, 12], "includ": [11, 12, 52, 53, 77, 105, 106], "30": [11, 12, 77], "128": [11, 12, 60, 77], "print": [11, 12, 102, 107], "appropri": [12, 71, 77], "hidden_s": [13, 14, 19, 20, 102], "head": [13, 19, 20, 106], "is_caus": [13, 20, 60], "dropout_p": [13, 19, 20, 60], "multi": [13, 20], "attent": [13, 19, 20, 22, 40, 60, 82, 104, 107], "warn": [13, 19, 20, 25, 60, 103], "here": [13, 19, 20, 60, 77, 92, 94, 105, 107], "give": [13, 19, 20, 77, 92, 104, 106, 107], "incorrect": [13, 19, 20, 77, 100], "hidden": [13, 14, 19, 20], "causal": [13, 20, 60], "mask": [13, 19, 20, 22, 60, 77], "post": [13, 19, 20, 65, 66, 67, 77, 106], "dropout": [13, 19, 20, 60, 104], "expansion_factor": 14, "swiglu": 14, "intermedi": [14, 65, 66, 67, 91, 101], "increas": [14, 60, 100], "rel": [14, 20, 40, 56, 58, 77, 82, 87, 106], "mup_typ": [15, 71, 74], "mup_scaling_depth": [15, 71, 74], "\u03bcp": [15, 40, 72, 74, 75, 82], "object": [15, 65, 66, 67, 73, 74, 75, 77, 84, 91, 100], "annot": [15, 23, 74, 102], "parameterdata": [15, 68, 69, 70, 71, 74, 78], "protocol": [15, 74, 75, 78], "assert": [15, 60, 74], "tupl": [16, 22, 28, 38, 43, 44, 52, 58, 59, 65, 66, 67, 77, 96, 100, 101, 102, 107], "rm": [16, 59, 104], "normalis": [16, 58], "trail": 16, "root": 16, "squar": [16, 55, 60, 65, 66, 77], "sigmoid": [17, 61, 62, 77], "known": [17, 61, 105], "swish": [17, 61], "sigma": [17, 61, 62, 77], "logist": [17, 61, 62], "gelu": [17, 61, 92, 104, 107], "wa": [17, 61, 77, 101], "coin": [17, 61], "reinforc": [17, 61], "gate": [17, 61, 62], "experi": [17, 61, 107], "later": [17, 61, 77], "lie": [18, 63], "li": 18, "adjust": 18, "accordingli": 18, "defin": [18, 28, 60, 63, 64, 65, 66, 67, 76, 77, 92, 100, 107], "x_i": [18, 63], "sum_j": [18, 63], "x_j": [18, 63], "unspecifi": [18, 65, 66, 67, 77], "inf": [18, 60, 77], "along": [18, 22, 63, 65, 66, 67, 77, 107], "slice": [18, 63, 77, 87, 101], "vocab_s": 19, "residual_sc": 19, "callabl": [19, 39, 40, 57, 65, 66, 67, 71, 77, 82, 88, 92, 94, 95, 96, 101, 102, 107], "transformer_residual_scaling_rul": [19, 104], "local": [19, 102], "_tau": 19, "decod": 19, "current": [19, 25, 60, 65, 66, 67, 77, 85, 92, 94, 101, 103, 104, 105, 107], "just": [19, 22, 65, 66, 67, 77, 100, 101, 106, 107], "demonstr": [19, 104], "lack": [19, 73], "kei": [19, 60, 73, 77, 84, 91, 92, 104], "usag": [19, 35, 77], "infer": [19, 77], "token": [19, 22, 25, 103], "vocabulari": 19, "residu": [19, 20, 40, 56, 57, 58, 82, 92, 106, 107], "scheme": [19, 40, 77, 82], "control": [19, 60, 64, 107], "trunk": 19, "core": [19, 104, 106], "mhsa_tau": 20, "mlp_tau": 20, "prenorm": 20, "http": [20, 77, 104, 107], "arxiv": 20, "org": 20, "ab": [20, 77], "2002": 20, "04745": 20, "branch": [20, 56, 57, 58, 107], "skip": [20, 56, 57, 58, 65, 66, 67, 100, 106, 107], "mlp": [20, 40, 82, 102, 104, 107], "tool": [21, 107], "analys": [21, 102, 107], "metric": [21, 23, 24, 91, 104], "within": [21, 77, 85, 94, 101, 107], "pretrainedtokenizerbas": [22, 25, 103], "batch_siz": [22, 25, 100, 103, 107], "seq_len": [22, 25, 103], "dataset_path": [22, 25, 103], "wikitext": [22, 25, 103], "dataset_nam": [22, 25, 103], "103": [22, 25, 103], "v1": [22, 25, 103], "shuffle_buffer_s": 22, "10000": 22, "seed": 22, "1472": 22, "dataset": [22, 25, 103], "shift": [22, 77], "length": [22, 25, 77, 103, 106], "huggingfac": [22, 25, 103], "visualis": [22, 24, 91, 104], "random": [22, 60, 77, 100], "chunk": [22, 77], "determin": [22, 54, 107], "10_000": 22, "shuffl": 22, "input_idx": 22, "attn_mask": [22, 25, 60, 103], "g": [23, 24, 40, 62, 65, 66, 77, 82, 86, 87, 91, 100, 106], "graph": [23, 24, 25, 65, 66, 77, 86, 87, 88, 91, 92, 94, 95, 96, 97, 100, 101, 103], "datafram": 23, "convert": [23, 43, 77, 97, 107], "fx": [23, 24, 85, 86, 87, 88, 91, 94, 95, 97, 101, 102, 107], "panda": 23, "indend": 23, "been": [23, 24, 28, 77, 86, 87, 92, 100, 107], "track_scal": [23, 24, 25, 86, 87, 103, 104], "possibli": [23, 24], "scales_graph": [23, 24, 25, 86, 87, 91, 103], "result": [23, 24, 56, 58, 77, 91, 101], "inform": [23, 60, 77, 84, 87, 91], "intern": [23, 89, 90, 91], "plot": [23, 25, 91, 103, 104], "pd": 23, "titl": 24, "mean_ab": [24, 84], "prune_same_scal": 24, "show_arrow": 24, "show_error_bar": 24, "show_zero_tensor": 24, "xmin": 24, "xmax": 24, "ax": [24, 25, 103], "matplotlib": [24, 25, 103], "intend": [24, 65, 66, 67, 87, 91, 99, 100], "inpt": [24, 86, 87, 91], "prune": [24, 25, 86, 87, 88, 103], "deem": [24, 25, 87, 103], "perspect": [24, 25, 103, 107], "faint": [24, 25, 103], "colour": [24, 25, 103], "horizont": [24, 25, 103], "line": [24, 25, 103, 107], "row": [24, 25, 49, 77, 103], "repres": [24, 25, 77, 84, 88, 91, 92, 94, 96, 99, 100, 101, 102, 103, 107], "bar": [24, 25, 103], "maximum": [24, 25, 42, 49, 77, 103], "minimum": [24, 25, 42, 77, 103], "seen": [24, 25, 77, 103, 106, 107], "show": [24, 104, 107], "axi": [24, 77], "abs_mean": [24, 84], "std": [24, 77, 84], "abs_max": [24, 84], "abs_min": [24, 84], "numel": [24, 77, 84], "practic": [24, 77, 91, 95, 107], "reshap": [24, 77, 87], "clearer": 24, "arrow": 24, "denot": [24, 77, 107], "max": [24, 65, 66, 77, 107], "min": [24, 77, 107], "displai": 24, "plot_kwarg": [25, 103], "experiment": [25, 77, 89, 90, 92, 103], "conveni": [25, 94, 103], "combin": [25, 52, 53, 54, 56, 57, 77, 85, 103, 107], "example_batch": [25, 103, 104], "wide": [25, 92, 103], "interfac": [25, 92, 94, 103], "futur": [25, 77, 85, 100, 103], "now": [25, 77, 91, 95, 100, 103, 107], "re": [25, 63, 77, 103, 104, 105, 106, 107], "base": [25, 52, 53, 60, 67, 71, 76, 100, 103, 107], "templat": [25, 103], "tracked_model": [25, 103], "keyword": [25, 60, 77, 101, 103], "alias": 26, "arithmet": 27, "group": [27, 28, 29, 30, 64, 65, 66, 67, 71, 77], "constrain": [27, 29, 30, 92, 107], "constraint_nam": 28, "rais": [28, 60, 73, 77, 100], "valueerror": 28, "geometr": [29, 77, 107], "harmon": 30, "xavier": [30, 107], "glorot": [30, 107], "output_scal": [31, 32, 33, 34, 39, 107], "grad_input_scal": [31, 33, 39, 107], "select": [31, 32, 33, 34, 60, 77, 106], "op": [31, 32, 33, 34, 65, 66, 67, 77, 85, 89, 90, 100, 101, 105, 107], "equal": [31, 32, 33, 34, 40, 49, 56, 57, 58, 77, 82, 107], "left_grad_scal": [32, 34], "right_grad_scal": [32, 34], "left": [32, 34, 54, 60, 77, 107], "right": [32, 34, 54], "grad": [33, 39, 52, 65, 66, 67, 77, 80, 81, 100, 107], "compon": 35, "advanc": 35, "librari": [35, 91, 92, 104, 105, 107], "alpha": [37, 46, 77], "lower": [37, 60, 65, 66, 67], "upper": [37, 60, 77], "interpol": [37, 77], "logarithm": 37, "space": 37, "constant": 37, "ratio": [37, 40, 56, 57, 58, 77, 82, 107], "limit": [37, 60, 91, 104], "keepdim": [38, 77], "wise": [39, 50, 55], "take": [39, 60, 77, 92, 107], "its": [39, 54, 65, 66, 77, 81, 84, 85, 96, 100, 104, 107], "kwarg": [39, 65, 66, 67, 75, 77, 96, 100, 101], "residual_mult": [40, 82], "residual_attn_ratio": [40, 82], "tau": [40, 56, 57, 58, 67, 82, 107], "rule": [40, 60, 64, 65, 66, 67, 77, 82, 106], "stack": [40, 82, 107], "start": [40, 77, 82, 101, 107], "ensur": [40, 60, 77, 82, 100, 101, 106, 107], "varianc": [40, 82, 106], "attn": [40, 82], "hyperparamet": [40, 82, 107], "total": [40, 77, 82], "appendix": [40, 82], "ffn": [40, 82], "fn": [40, 57, 82, 95], "simul": [41, 77, 89, 90], "exponent_bit": [42, 43, 44], "mantissa_bit": [42, 43, 44], "round": [42, 77, 89, 90], "stochast": [42, 65, 89, 90], "srbit": 42, "represent": [42, 77, 89, 90], "properti": [42, 77, 106], "bit": [42, 77], "max_absolute_valu": 42, "absolut": [42, 77], "min_absolute_norm": 42, "min_absolute_subnorm": 42, "subnorm": 42, "quantis": [42, 89, 90], "differenti": [42, 65, 66, 67, 77, 100], "quantise_bwd": 42, "quantise_fwd": 42, "fpformat": [43, 44, 89, 104], "_i": 46, "broadcast": [46, 54, 60, 77], "promot": 46, "complex": [46, 77, 101], "to_left_grad_scal": [46, 54, 104], "to_right_grad_scal": [46, 54, 104], "0202": 46, "0985": 46, "3506": 46, "6056": 46, "21": 46, "19": 46, "3944": 46, "b": [46, 52, 53, 67, 77, 100], "9732": 46, "3497": 46, "6245": 46, "4022": 46, "3743": 46, "7724": 46, "5811": 46, "8017": 46, "7695": 46, "3930": 46, "3672": 46, "1450": 46, "18": [46, 77], "6971": 46, "0736": 46, "17": 46, "0994": 46, "3216": 46, "7845": 46, "1610": 46, "1868": 46, "4090": 46, "8": [46, 60, 65, 66, 77, 107], "9902": 46, "3667": 46, "7": [46, 77], "3925": 46, "6147": 46, "crossentropyloss": [47, 104, 107], "predict": [47, 77], "section": [47, 77, 104, 107], "divid": [47, 77], "randint": [47, 77], "int64": [47, 77], "dictionaryand": 49, "analyt": 49, "respect": [49, 65, 66, 67, 77, 100], "column": [49, 77], "modifi": [49, 65, 66, 67, 77, 100, 106], "under": [49, 106], "v": [49, 65, 66, 67, 73, 77, 106], "embedding_matrix": 49, "rand": [49, 60, 77], "8490": 49, "9625": 49, "6753": 49, "9666": 49, "7761": 49, "6108": 49, "6246": 49, "9751": 49, "3618": 49, "4161": 49, "2419": 49, "7383": 49, "0237": 49, "7794": 49, "0528": 49, "3385": 49, "8612": 49, "1867": 49, "zero_": [49, 77], "5609": 49, "5384": 49, "8720": 49, "6262": 49, "2438": 49, "7471": 49, "layernorm": [51, 104], "scale_pow": 52, "xa": [52, 53], "layout": [52, 53, 54, 77, 101], "autograd": [52, 53, 54, 65, 66, 67, 77, 100], "miss": [52, 53, 54, 105], "pleas": [52, 53, 54, 60, 65, 66, 77, 100, 107], "open": [52, 53, 54, 106], "request": [52, 53, 54, 65, 66, 67, 105], "power": 52, "product": [54, 60, 77, 104], "behavior": [54, 65, 66, 67, 77, 100], "prepend": [54, 65, 66, 67, 77], "purpos": [54, 77, 89, 90, 107], "remov": [54, 65, 66, 67, 73, 77, 87, 96, 107], "least": [54, 77], "thu": [54, 107], "j": [54, 77], "logic": 54, "valid": [54, 77, 92, 107], "even": [54, 77, 107], "though": [54, 58, 77, 85, 94, 105, 107], "particular": [54, 77, 92, 106, 107], "mm": [54, 77], "measur": [55, 107], "mseloss": 55, "togeth": [56, 107], "conjunct": [56, 58, 91, 95, 107], "residual_split": [56, 57, 104, 107], "come": [56, 77, 85], "favor": [56, 57, 58], "maintain": 57, "residual_add": [57, 58, 104, 107], "split": [58, 77], "prior": [58, 60, 100], "necessari": [58, 77, 86, 100, 107], "delai": [58, 77, 107], "still": [58, 77, 85, 92, 100, 105], "benefici": 58, "behav": [58, 65, 66, 67, 77], "rmsnorm": [59, 104], "queri": 60, "whatev": [60, 77], "underli": [60, 77, 85], "avail": [60, 91, 107], "flash": 60, "greater": [60, 77], "identifi": [60, 77, 86, 92], "versu": 60, "effici": [60, 77, 100, 107], "def": [60, 76, 77, 92, 100, 102, 107], "scale_factor": 60, "math": [60, 102], "els": [60, 65, 66, 67, 73, 77], "attn_bia": 60, "temp_mask": 60, "tril": [60, 77], "diagon": [60, 77], "masked_fill_": [60, 77], "logical_not": [60, 77], "attn_weight": 60, "transpos": [60, 77], "subject": [60, 77], "alwai": [60, 77, 91, 100], "accord": [60, 77, 92], "To": [60, 77, 97, 101, 104, 107], "disabl": [60, 71], "sure": 60, "mymodel": 60, "__init__": [60, 102, 107], "super": [60, 102, 107], "There": [60, 77, 100, 107], "flashattent": 60, "faster": [60, 107], "parallel": 60, "partit": 60, "memori": [60, 65, 66, 67, 77, 100, 107], "match": [60, 65, 66, 67, 77], "formul": 60, "kernel": [60, 77, 107], "cuda": [60, 65, 66, 67, 77], "backend": [60, 85, 94, 95], "attempt": [60, 107], "most": [60, 77, 100, 105, 106, 107], "fine": [60, 65, 66, 67, 77], "grain": 60, "context": [60, 65, 66, 67, 77, 100], "manag": [60, 77], "prefer": 60, "mechan": [60, 107], "sdpa_kernel": 60, "enable_flash_sdp": 60, "global": [60, 71, 85], "enable_mem_efficient_sdp": 60, "enable_math_sdp": 60, "fuse": [60, 65, 66, 67, 107], "event": [60, 77], "reason": [60, 91, 105, 106], "cannot": [60, 77], "due": [60, 65, 66, 67, 77], "natur": 60, "float64": [60, 65, 66, 67, 77], "numerical_accuraci": 60, "circumst": 60, "cudnn": 60, "nondeterminist": [60, 77], "undesir": 60, "try": [60, 65, 66, 67, 107], "potenti": [60, 77], "cost": [60, 77], "ev": 60, "part": 60, "score": 60, "triangular": 60, "form": [60, 107], "causalbia": 60, "thrown": [60, 77], "32": [60, 77], "sdp_kernel": 60, "enable_math": 60, "silu": [62, 104], "desir": [63, 77, 107], "cast": [63, 77], "overflow": [63, 107], "mup": [64, 65, 66, 67, 71], "adam": [64, 66, 69, 71, 104], "adamw": [64, 69, 104], "sgd": [64, 70, 77, 104], "box": [64, 101, 104, 107], "scaled_paramet": [64, 104], "finer": 64, "downstream": 64, "lr": [64, 65, 66, 67, 68, 69, 70, 71, 77], "param": [65, 66, 67, 68, 69, 70, 71, 77], "dict": [65, 66, 67, 71, 73, 77, 92, 96, 101], "001": [65, 66, 67], "weight_decai": [65, 66, 67, 71], "independent_weight_decai": [65, 66, 67, 71], "allow_non_unit_scaling_param": [65, 66, 67, 71], "110mm": [65, 66, 67], "4pt": [65, 66, 67], "textbf": [65, 66, 67], "beta_1": [65, 66], "beta_2": [65, 66], "theta_0": [65, 66, 67], "theta": [65, 66, 67], "hspace": [65, 66, 67], "13mm": [65, 66, 67], "lambda": [65, 66, 67, 77], "decai": [65, 66, 67, 71], "textit": [65, 66, 67], "amsgrad": [65, 66], "maxim": [65, 66, 67], "m_0": [65, 66], "leftarrow": [65, 66, 67], "moment": [65, 66], "v_0": [65, 66], "widehat": [65, 66], "ex": [65, 66, 67], "5mm": [65, 66, 67], "10mm": [65, 66, 67], "g_t": [65, 66, 67], "nabla_": [65, 66, 67], "f_t": [65, 66, 67], "theta_": [65, 66, 67], "neq": [65, 67], "m_t": [65, 66], "m_": [65, 66], "v_t": [65, 66], "v_": [65, 66], "2_t": [65, 66], "big": [65, 66], "theta_t": [65, 66, 67], "bf": [65, 66, 67], "further": [65, 66, 77, 106, 107], "refer": [65, 66, 77, 104, 107], "rate": [65, 66, 67, 71, 106], "yet": [65, 66, 77, 92, 104, 106, 107], "our": [65, 66, 92, 94, 100, 104, 105, 106, 107], "captur": [65, 66, 95, 97], "coeffici": [65, 66], "999": [65, 66, 102, 107], "term": [65, 66, 77, 107], "l2": [65, 67], "penalti": [65, 67], "whether": [65, 66, 67, 77, 92, 100], "variant": [65, 66], "converg": [65, 66], "foreach": [65, 66, 67], "loop": [65, 66, 67], "sinc": [65, 66, 67, 77], "usual": [65, 66, 67, 77, 107], "significantli": [65, 66, 67, 107], "sizeof": [65, 66, 67], "peak": [65, 66, 67], "tensorlist": [65, 66, 67], "prohibit": [65, 66, 67], "fewer": [65, 66, 67], "through": [65, 66, 67, 77, 100, 101, 105, 106, 107], "switch": [65, 66, 67], "flag": [65, 66, 67], "minim": [65, 66, 67], "safe": [65, 66, 77], "impair": [65, 66, 67], "ungraph": [65, 66], "leav": [65, 66, 67, 77, 107], "occur": [65, 66, 67, 77, 107], "step": [65, 66, 67, 71, 77, 107], "float32": [65, 66, 67, 77], "bfloat16": [65, 66, 67, 77], "add_param_group": [65, 66, 67], "param_group": [65, 66, 67, 71], "tune": [65, 66, 67, 107], "frozen": [65, 66, 67], "made": [65, 66, 67, 100], "trainabl": [65, 66, 67], "progress": [65, 66, 67, 77], "load_state_dict": [65, 66, 67, 77], "state_dict": [65, 66, 67], "load": [65, 66, 67, 77], "state": [65, 66, 67, 77], "register_load_state_dict_post_hook": [65, 66, 67], "hook": [65, 66, 67, 77, 100], "removablehandl": [65, 66, 67], "signatur": [65, 66, 67, 76, 77, 100], "fire": [65, 66, 67], "alreadi": [65, 66, 67, 77], "handl": [65, 66, 67, 77, 100, 101], "util": [65, 66, 67, 77, 104, 107], "removeablehandl": [65, 66, 67], "register_load_state_dict_pre_hook": [65, 66, 67], "shallow": [65, 66, 67, 73], "copi": [65, 66, 67, 73, 77, 92], "new": [65, 66, 67, 73, 77, 85, 89, 90, 91, 94, 95, 96, 104, 105, 107], "register_state_dict_post_hook": [65, 66, 67], "register_state_dict_pre_hook": [65, 66, 67], "register_step_post_hook": [65, 66, 67], "register_step_pre_hook": [65, 66, 67], "new_arg": [65, 66, 67], "new_kwarg": [65, 66, 67], "hold": [65, 66, 67, 77], "Its": [65, 66, 67], "content": [65, 66, 67, 100], "characterist": [65, 66, 67], "itself": [65, 66, 67, 77], "NOT": [65, 66, 67], "map": [65, 66, 67, 77, 101], "metadata": [65, 66, 67, 75], "associ": [65, 66, 67, 91], "zip": [65, 66, 67], "actual": [65, 66, 67, 77], "without": [65, 66, 67, 77, 85, 102, 105, 107], "verif": [65, 66, 67], "might": [65, 66, 67, 77, 107], "momentum_buff": [65, 66, 67], "01": [65, 66, 67, 77, 107], "closur": [65, 66, 67], "reevalu": [65, 66, 67], "zero_grad": [65, 66, 67], "set_to_non": [65, 66, 67], "reset": [65, 66, 67], "footprint": [65, 66, 67], "modestli": [65, 66, 67], "howev": [65, 66, 67, 77, 92, 94, 106, 107], "tri": [65, 66, 67, 77], "access": [65, 66, 67, 77, 91, 100], "attribut": [65, 66, 67, 72, 76, 77, 100, 101], "guarante": [65, 66, 67, 77, 101, 105], "did": [65, 66, 67], "receiv": [65, 66, 67, 107], "altogeth": [65, 66, 67], "decoupl": 66, "mu": [67, 77], "momentum": 67, "dampen": 67, "nesterov": 67, "15mm": 67, "_t": 67, "g_": 67, "formula": [67, 100], "deep": [67, 106, 107], "__": 67, "loss_fn": 67, "lr_scale_func": 71, "adam_lr_scale_func": 71, "paramst": 71, "tag": [71, 75], "lr_scale_func_sgd": [71, 104], "overridden": [71, 100], "fail": [71, 77, 107], "rememb": 73, "clear": [73, 101], "od": 73, "fromkei": 73, "move_to_end": 73, "move": [73, 77, 85, 94], "keyerror": 73, "pop": 73, "found": [73, 104, 106, 107], "popitem": 73, "pair": [73, 77, 99], "lifo": 73, "fifo": 73, "setdefault": 73, "present": [73, 77, 96, 104, 107], "extra": [75, 107], "implicitli": 75, "proto": 76, "meth": 76, "Such": 76, "primarili": 76, "static": [76, 100], "checker": 76, "recogn": 76, "structur": [76, 100, 101], "subtyp": 76, "duck": 76, "func": [76, 100], "pep": 76, "544": 76, "decor": [76, 100], "runtime_check": 76, "act": 76, "simpl": [76, 77], "mind": 76, "runtim": 76, "presenc": 76, "genproto": 76, "conjug": 77, "conj": 77, "matric": 77, "real": 77, "mh": 77, "revers": 77, "permut": 77, "throw": 77, "releas": [77, 85, 104, 107], "mt": 77, "arang": 77, "ndim": 77, "abs_": 77, "alia": 77, "absolute_": 77, "aco": 77, "acos_": 77, "acosh": 77, "acosh_": 77, "add_": 77, "addbmm": 77, "batch1": 77, "batch2": 77, "addbmm_": 77, "addcdiv": 77, "tensor1": 77, "tensor2": 77, "addcdiv_": 77, "addcmul": 77, "addcmul_": 77, "addmm": 77, "mat1": 77, "mat2": 77, "addmm_": 77, "addmv": 77, "mat": 77, "vec": 77, "addmv_": 77, "addr": 77, "vec1": 77, "vec2": 77, "addr_": 77, "adjoint": 77, "align_a": 77, "explicit": 77, "align_to": 77, "127": 77, "refine_nam": 77, "img": 77, "scale_channel": 77, "num_channel": 77, "more_img": 77, "video": [77, 104], "agnost": 77, "api": [77, 101, 104, 107], "ellipsi": 77, "expand": [77, 100], "mention": 77, "appear": 77, "string": [77, 102], "unment": 77, "named_tensor": 77, "front": 77, "keep": [77, 107], "rest": 77, "allclos": 77, "rtol": [77, 87], "atol": 77, "08": 77, "equal_nan": 77, "amax": 77, "amin": 77, "aminmax": 77, "angl": 77, "apply_": 77, "cpu": 77, "arcco": 77, "arccos_": 77, "arccosh": 77, "arccosh_": 77, "arcsin": 77, "arcsin_": 77, "arcsinh": 77, "arcsinh_": 77, "arctan": 77, "arctan2": 77, "arctan2_": 77, "atan2_": 77, "arctan_": 77, "arctanh": 77, "arctanh_": 77, "argmax": 77, "argmin": 77, "argsort": 77, "descend": [77, 101], "argwher": 77, "as_strid": 77, "stride": 77, "storage_offset": 77, "as_strided_": 77, "as_strided_scatt": 77, "src": 77, "as_subclass": 77, "cl": 77, "pointer": 77, "stai": [77, 107], "attach": 77, "subclass": [77, 100], "asin": 77, "asin_": 77, "asinh": 77, "asinh_": 77, "atan": 77, "atan2": 77, "atan_": 77, "atanh": 77, "atanh_": 77, "retain_graph": 77, "create_graph": 77, "wrt": 77, "addition": 77, "accumul": 77, "stream": 77, "semant": [77, 101], "leaf": 77, "grad_fn": 77, "strictli": 77, "reli": [77, 92], "github": [77, 104, 107], "com": [77, 104, 107], "pull": 77, "60521": 77, "issuecom": 77, "867061780": 77, "omit": 77, "freed": 77, "nearli": 77, "much": [77, 92, 107], "deriv": [77, 106, 107], "were": [77, 100], "baddbmm": 77, "baddbmm_": 77, "texttt": 77, "bernoulli_": 77, "fill": 77, "locat": 77, "integr": 77, "draw": 77, "binari": 77, "th": 77, "_tensor": 77, "memory_format": [77, 101], "preserve_format": 77, "bincount": 77, "minlength": 77, "bitwise_and": 77, "bitwise_and_": 77, "bitwise_left_shift": 77, "bitwise_left_shift_": 77, "bitwise_not": 77, "bitwise_not_": 77, "bitwise_or": 77, "bitwise_or_": 77, "bitwise_right_shift": 77, "bitwise_right_shift_": 77, "bitwise_xor": 77, "bitwise_xor_": 77, "bmm": 77, "broadcast_to": 77, "byte": 77, "uint8": 77, "cauchy_": 77, "median": 77, "drawn": 77, "cauchi": 77, "dfrac": 77, "cdoubl": 77, "complex128": 77, "ceil": 77, "ceil_": 77, "cfloat": 77, "complex64": 77, "chalf": 77, "complex32": 77, "char": 77, "int8": 77, "choleski": 77, "cholesky_invers": 77, "cholesky_solv": 77, "input2": 77, "clamp": 77, "clamp_": 77, "clip": [77, 107], "clip_": 77, "clone": [77, 100, 104, 107], "coalesc": 77, "uncoalesc": 77, "coo": 77, "col_indic": 77, "csr": 77, "sparse_csr": 77, "nnz": 77, "int32": 77, "mkl": 77, "routin": 77, "avoid": 77, "downcast": 77, "lose": 77, "ey": 77, "to_sparse_csr": 77, "conj_phys": 77, "conj_physical_": 77, "contigu": 77, "contiguous_format": 77, "copy_": 77, "non_block": 77, "resid": 77, "gpu": 77, "asynchron": 77, "host": 77, "copysign": 77, "copysign_": 77, "corrcoef": 77, "cos_": 77, "cosh": 77, "cosh_": 77, "count_nonzero": 77, "cov": 77, "correct": [77, 89, 90, 100, 106, 107], "fweight": 77, "aweight": 77, "crow_indic": 77, "compress": 77, "destin": 77, "pin": 77, "cummax": 77, "cummin": 77, "cumprod": 77, "cumprod_": 77, "cumsum": 77, "cumsum_": 77, "data_ptr": 77, "address": [77, 85, 107], "deg2rad": 77, "deg2rad_": 77, "dense_dim": 77, "dens": 77, "len": 77, "sparse_dim": 77, "hybrid": 77, "dequant": 77, "quantiz": 77, "det": 77, "detach": 77, "never": [77, 87], "affect": 77, "share": [77, 100, 107], "storag": [77, 100], "trigger": 77, "detach_": 77, "diag": 77, "diag_emb": 77, "offset": 77, "dim1": 77, "dim2": 77, "diagflat": 77, "diagonal_scatt": 77, "diff": 77, "digamma": 77, "digamma_": 77, "dim_ord": 77, "physic": 77, "laid": 77, "outermost": 77, "innermost": 77, "channels_last": 77, "dist": 77, "div": 77, "rounding_mod": 77, "div_": 77, "divide_": 77, "doubl": [77, 100], "dsplit": 77, "split_size_or_sect": 77, "element_s": 77, "individu": 77, "eq": 77, "eq_": 77, "erf": 77, "erf_": 77, "erfc": 77, "erfc_": 77, "erfinv": 77, "erfinv_": 77, "exp2": 77, "exp2_": 77, "exp_": 77, "alloc": 77, "As": [77, 90, 107], "especi": 77, "write": 77, "expand_a": 77, "expm1": 77, "expm1_": 77, "exponential_": 77, "lambd": 77, "pdf": 77, "densiti": 77, "theori": [77, 107], "exponenti": 77, "interv": 77, "impli": 77, "fill_": 77, "fill_diagonal_": 77, "fill_valu": 77, "wrap": [77, 95, 101, 102, 107], "main": 77, "tall": 77, "fix_": 77, "flatten": 77, "start_dim": 77, "end_dim": 77, "flip": 77, "fliplr": 77, "flipud": 77, "float_pow": 77, "expon": 77, "float_power_": 77, "floor": 77, "floor_": 77, "floor_divid": 77, "floor_divide_": 77, "fmax": 77, "fmin": 77, "fmod": 77, "divisor": 77, "fmod_": 77, "frac_": 77, "frexp": 77, "mantissa": 77, "gather": 77, "gcd": 77, "gcd_": 77, "ge": 77, "ge_": 77, "geometric_": 77, "trial": 77, "success": 77, "henc": [77, 92], "wherea": 77, "geqrf": 77, "ger": 77, "get_devic": 77, "ordin": 77, "greater_": 77, "greater_equ": 77, "greater_equal_": 77, "gt": [77, 106], "gt_": 77, "half": 77, "hardshrink": 77, "has_nam": 77, "heavisid": 77, "heaviside_": 77, "histc": 77, "histogram": 77, "hsplit": 77, "hypot": 77, "hypot_": 77, "i0": 77, "i0_": 77, "igamma": 77, "igamma_": 77, "igammac": 77, "igammac_": 77, "imaginari": 77, "3100": 77, "3553j": 77, "5445": 77, "7896j": 77, "6492": 77, "0633j": 77, "0638": 77, "8119j": 77, "3553": 77, "7896": 77, "0633": 77, "8119": 77, "index_add": 77, "index_add_": [77, 100], "subtract": 77, "index_copi": 77, "index_copy_": 77, "duplic": 77, "index_fil": 77, "index_fill_": 77, "index_put": 77, "index_put_": 77, "put": [77, 107], "express": [77, 92, 96, 106], "undefin": [77, 100], "index_reduce_": 77, "include_self": 77, "prod": 77, "identit": 77, "11": 77, "12": 77, "44": 77, "72": 77, "14": 77, "22": 77, "36": 77, "index_select": 77, "inner": [77, 92], "int_repr": 77, "uint8_t": 77, "is_coalesc": 77, "is_complex": 77, "is_conj": 77, "is_contigu": 77, "is_cpu": 77, "is_cuda": 77, "is_floating_point": 77, "is_infer": 77, "is_ipu": 77, "is_leaf": 77, "convent": [77, 101], "popul": [77, 101], "retain_grad": 77, "engin": [77, 100], "requires_grad_": [77, 91, 102, 107], "is_meta": 77, "meta": [77, 91], "carri": 77, "is_mp": 77, "mp": 77, "is_neg": 77, "neg": 77, "is_pin": 77, "is_quant": 77, "is_set_to": 77, "exact": [77, 107], "is_shar": 77, "is_sign": 77, "sign": [77, 107], "is_spars": 77, "is_sparse_csr": 77, "is_xla": 77, "xla": 77, "is_xpu": 77, "xpu": 77, "isclos": 77, "isfinit": 77, "isinf": 77, "isnan": 77, "isneginf": 77, "isposinf": 77, "isreal": 77, "istft": 77, "n_fft": 77, "hop_length": 77, "win_length": 77, "window": 77, "center": 77, "onesid": 77, "return_complex": 77, "tolist": 77, "items": 77, "kron": 77, "kthvalu": 77, "lcm": 77, "lcm_": 77, "ldexp": 77, "ldexp_": 77, "le": 77, "le_": 77, "lerp": 77, "lerp_": 77, "less": 77, "lt": 77, "less_": 77, "less_equ": 77, "less_equal_": 77, "lgamma": 77, "lgamma_": 77, "log10": 77, "log10_": 77, "log1p": 77, "log1p_": 77, "log2": 77, "log2_": 77, "log_": 77, "log_normal_": 77, "parameter": 77, "ln": 77, "logaddexp": 77, "logaddexp2": 77, "logcumsumexp": 77, "logdet": 77, "logical_and": 77, "logical_and_": 77, "logical_not_": 77, "logical_or": 77, "logical_or_": 77, "logical_xor": 77, "logical_xor_": 77, "logit_": 77, "logsumexp": 77, "lt_": 77, "lu": 77, "pivot": 77, "get_info": 77, "lu_solv": 77, "lu_data": 77, "lu_pivot": 77, "map_": 77, "masked_fil": 77, "booltensor": 77, "masked_scatt": 77, "masked_scatter_": 77, "continu": 77, "occurr": 77, "mani": [77, 92, 100, 106, 107], "masked_select": 77, "matmul": [77, 89, 90, 104], "matrix_exp": 77, "matrix_pow": 77, "linalg": 77, "module_load": 77, "get_swap_module_params_on_convers": 77, "buffer": 77, "remap": 77, "swap_tensor": 77, "moveaxi": 77, "movedim": 77, "msort": 77, "mul": 77, "mul_": 77, "multinomi": 77, "num_sampl": 77, "multiply_": 77, "mv": 77, "mvlgamma": 77, "mvlgamma_": 77, "idx": [77, 100], "unnam": 77, "charact": [77, 106], "underscor": 77, "variabl": [77, 100], "nan_to_num": 77, "nan": 77, "posinf": 77, "neginf": 77, "nan_to_num_": 77, "nanmean": 77, "nanmedian": 77, "nanquantil": 77, "q": [77, 106], "nansum": 77, "narrow": 77, "narrow_copi": 77, "nbyte": 77, "consum": 77, "ndimens": 77, "ne": 77, "ne_": 77, "neg_": 77, "negative_": 77, "nelement": 77, "new_empti": 77, "pin_memori": 77, "uniniti": 77, "record": [77, 91, 100, 101], "would": [77, 100], "8182e": 77, "5765e": 77, "41": 77, "0545e": 77, "0949e": 77, "4842e": 77, "0000e": 77, "00": 77, "new_empty_strid": 77, "new_ful": 77, "141592": 77, "1416": 77, "new_on": 77, "new_tensor": 77, "want": [77, 105], "numpi": [77, 100], "arrai": 77, "from_numpi": 77, "read": [77, 107], "array_lik": 77, "new_zero": 77, "nextaft": 77, "nextafter_": 77, "nonzero": 77, "nonzero_stat": 77, "count": 77, "truncat": 77, "smaller": [77, 107], "invalid": 77, "input_tensor": 77, "static_s": 77, "rank": 77, "fro": 77, "normal_": 77, "not_equ": 77, "not_equal_": 77, "forc": 77, "ndarrai": 77, "convers": [77, 107], "reflect": [77, 102, 107], "vice": 77, "versa": 77, "resolve_conj": 77, "resolve_neg": 77, "isn": [77, 106], "won": 77, "shorthand": 77, "orgqr": 77, "ormqr": 77, "input3": 77, "outer": 77, "pinvers": 77, "polygamma": 77, "polygamma_": 77, "pow": 77, "pow_": 77, "put_": 77, "q_per_channel_axi": 77, "q_per_channel_scal": 77, "q_per_channel_zero_point": 77, "zero_point": 77, "q_scale": 77, "q_zero_point": 77, "qr": 77, "qscheme": 77, "qtensor": 77, "quantil": 77, "rad2deg": 77, "rad2deg_": 77, "discret": 77, "bound": 77, "53": 77, "ravel": 77, "reciproc": 77, "reciprocal_": 77, "record_stream": 77, "mark": [77, 100], "dealloc": [77, 101], "reus": 77, "until": [77, 107], "queu": 77, "complet": [77, 88, 105], "cach": [77, 85, 94], "awar": 77, "correctli": 77, "life": 77, "cycl": 77, "But": [77, 106], "unexpectedli": 77, "let": 77, "know": [77, 100, 106], "suitabl": 77, "side": 77, "abl": [77, 92, 107], "think": [77, 106], "carefulli": 77, "safeti": 77, "These": [77, 91, 107], "analog": 77, "tradeoff": 77, "gc": 77, "situat": 77, "lifetim": 77, "poll": 77, "race": 77, "creation": 77, "sync": 77, "back": 77, "suffici": [77, 92, 107], "realloc": 77, "done": [77, 92], "counterintuit": 77, "old": [77, 107], "becaus": [77, 107], "wait": 77, "concret": [77, 101], "s0": 77, "s1": 77, "wait_stream": 77, "some_comm_op": 77, "synchron": 77, "del": 77, "decid": 77, "immedi": 77, "wouldn": 77, "finish": 77, "profil": 77, "chrome": 77, "trace": [77, 92], "produc": [77, 107], "export_chrome_trac": 77, "earli": 77, "block": [77, 106, 107], "overlap": 77, "commun": 77, "late": 77, "live": 77, "guidanc": 77, "fsdp": 77, "cudacachingalloc": 77, "refin": 77, "special": [77, 92, 105], "renam": 77, "lift": 77, "coexist": 77, "nice": 77, "greedili": 77, "named_img": 77, "register_hook": 77, "execut": [77, 85, 94, 101], "register_post_accumulate_grad_hook": 77, "unless": [77, 100], "enable_grad": 77, "0100": 77, "0200": 77, "0300": 77, "remaind": 77, "remainder_": 77, "rename_map": 77, "position": 77, "drop": [77, 107], "One": 77, "renamed_img": 77, "height": 77, "width": 77, "rename_": 77, "maxnorm": 77, "renorm_": 77, "repeat": [77, 107], "similar": [77, 100, 107], "tile": 77, "repeat_interleav": 77, "output_s": 77, "fact": [77, 107], "tell": 77, "obtain": 77, "dataload": 77, "preprocess": 77, "sai": 77, "saved_weight": 77, "25": 77, "loaded_weight": 77, "5503": 77, "4926": 77, "1158": 77, "8303": 77, "1007": 77, "9853": 77, "2316": 77, "6606": 77, "compat": [77, 101], "reshape_a": 77, "resize_": 77, "resiz": 77, "fit": 77, "preserv": [77, 106, 107], "level": [77, 107], "reinterpret": 77, "unchang": [77, 80], "custom": [77, 92, 100], "set_": 77, "use_deterministic_algorithm": 77, "fill_uninitialized_memori": 77, "go": [77, 100, 106, 107], "unaffect": 77, "resize_as_": 77, "retains_grad": 77, "roll": 77, "rot90": 77, "decim": 77, "round_": 77, "rsqrt": 77, "rsqrt_": 77, "scatter": 77, "scatter_": 77, "manner": 77, "moreov": 77, "inclus": 77, "uniqu": 77, "pick": 77, "arbitrarili": 77, "propag": 77, "scatter_add_": 77, "scatter_reduce_": 77, "23": 77, "4600": 77, "2300": 77, "previou": [77, 87, 92], "scatter_add": 77, "fashion": 77, "scatter_reduc": 77, "select_scatt": 77, "sgn": 77, "sgn_": 77, "share_memory_": 77, "untypedstorag": 77, "short": 77, "int16": 77, "sigmoid_": 77, "sign_": 77, "signbit": 77, "sin": 77, "sin_": 77, "sinc_": 77, "sinh": 77, "sinh_": 77, "slice_scatt": 77, "slogdet": 77, "smm": 77, "sort": [77, 100], "sparse_mask": 77, "filter": 77, "advis": 77, "whose": [77, 102], "nse": 77, "cat": 77, "sparse_coo_tensor": 77, "6550": 77, "2397": 77, "1611": 77, "0779": 77, "2326": 77, "0558": 77, "4711": 77, "9678": 77, "5138": 77, "0411": 77, "9417": 77, "5158": 77, "0793": 77, "0036": 77, "2569": 77, "1055": 77, "sparse_coo": 77, "sparse_resize_": 77, "sparse_resize_and_clear_": 77, "split_siz": 77, "sqrt_": 77, "square_": 77, "squeez": 77, "squeeze_": 77, "sspaddmm": 77, "stft": 77, "pad_mod": 77, "typedstorag": 77, "directli": [77, 100, 105, 107], "untyped_storag": 77, "storage_typ": 77, "jump": 77, "next": [77, 107], "sub": 77, "sub_": 77, "subtract_": 77, "sum_to_s": 77, "svd": 77, "compute_uv": 77, "swapax": 77, "axis0": 77, "axis1": 77, "swapaxes_": 77, "swapdim": 77, "dim0": 77, "swapdims_": 77, "t_": 77, "take_along_dim": 77, "tan": 77, "tan_": 77, "tanh_": 77, "tensor_split": 77, "indices_or_sect": 77, "5044": 77, "0005": 77, "3310": 77, "0584": 77, "cuda0": 77, "to_dens": 77, "masked_grad": 77, "to_mkldnn": 77, "mkldnn": 77, "to_padded_tensor": 77, "to_spars": 77, "sparsedim": 77, "coordin": 77, "blocksiz": 77, "could": [77, 107], "sparse_csc": 77, "sparse_bsr": 77, "sparse_bsc": 77, "bsr": 77, "bsc": 77, "runtimeerror": [77, 100], "except": 77, "evenli": 77, "csc": 77, "minu": 77, "divis": [77, 107], "sparsecsr": 77, "to_sparse_bsc": 77, "row_indic": 77, "ccol_indic": 77, "to_sparse_bsr": 77, "to_sparse_coo": 77, "_nnz": 77, "to_sparse_csc": 77, "nest": [77, 85, 92, 94], "012766935862600803": 77, "5415473580360413": 77, "08909505605697632": 77, "7729271650314331": 77, "topk": 77, "largest": 77, "transpose_": 77, "triangular_solv": 77, "unitriangular": 77, "tril_": 77, "triu": 77, "triu_": 77, "true_divid": 77, "true_divide_": 77, "trunc": 77, "trunc_": 77, "async": 77, "type_a": 77, "unbind": 77, "seq": 77, "unflatten": 77, "unfold": 77, "sizedim": 77, "happen": [77, 100, 106], "uniform_": 77, "return_invers": 77, "return_count": 77, "unique_consecut": 77, "elimin": [77, 107], "consecut": 77, "unsafe_chunk": 77, "unsafe_split": 77, "unsqueez": 77, "unsqueeze_": 77, "vdot": 77, "subspac": 77, "across": 77, "satisfi": 77, "condit": 77, "foral": 77, "unclear": 77, "z": [77, 100, 106], "2nd": 77, "3rd": 77, "proportion": 77, "twice": 77, "met": 77, "overload": 77, "torchscript": [77, 107], "program": [77, 107], "9482": 77, "0310": 77, "4999": 77, "5316": 77, "1520": 77, "7472": 77, "5617": 77, "8649": 77, "4724": 77, "0334": 77, "2976": 77, "8499": 77, "2109": 77, "9913": 77, "9607": 77, "6123": 77, "1064483442": 77, "1124191867": 77, "1069546515": 77, "1089989247": 77, "1105482831": 77, "1061112040": 77, "1057999968": 77, "1084397505": 77, "1071760287": 77, "1123489973": 77, "1097310419": 77, "1084649136": 77, "1101533110": 77, "1073668768": 77, "1082790149": 77, "1088634448": 77, "1000000000": 77, "0047": 77, "0310j": 77, "5316j": 77, "7472j": 77, "8649j": 77, "0334j": 77, "8499j": 77, "9913j": 77, "6123j": 77, "202": 77, "154": 77, "59": 77, "182": 77, "243": 77, "253": 77, "188": 77, "185": 77, "252": 77, "191": 77, "63": 77, "240": 77, "227": 77, "165": 77, "27": 77, "190": 77, "146": 77, "203": 77, "15": 77, "106": 77, "93": 77, "205": 77, "192": 77, "112": 77, "206": 77, "189": 77, "95": 77, "152": 77, "147": 77, "89": 77, "43": 77, "246": 77, "87": 77, "235": [77, 102], "226": 77, "254": 77, "111": 77, "117": 77, "177": [77, 107], "28": 77, "view_a": 77, "vsplit": 77, "xlogi": 77, "xlogy_": 77, "typeguard": 78, "dynamo": 83, "fwd_tensor": 84, "fwd": [84, 107], "bwd": [84, 102, 107], "slightli": [85, 87, 106, 107], "doesn": [85, 92, 94, 95, 97, 100], "unit_scal": [85, 91, 104, 105], "introduc": [85, 106, 107], "compos": 85, "_dynamo": [85, 94, 95], "optimis": [85, 94, 95, 104], "thereaft": [85, 94], "written": 85, "successfulli": 85, "rather": 85, "graphmodul": [85, 94, 101], "simulate_fp8": [85, 91, 92, 104, 105], "node": [86, 87, 88, 91, 96, 101], "suppli": [86, 87, 89, 92, 94, 100, 107], "pruned_graph": [86, 87], "52587890625e": 87, "negligibli": 87, "toler": 87, "signific": [87, 107], "onc": [88, 100], "fwd_format": 89, "bwd_format": 89, "torchdynamo": [89, 90, 92, 94, 95, 97], "scaled_dot_product_attent": [89, 90, 104], "inspect": [89, 90], "fp32": [89, 90, 107], "speedup": [89, 90, 107], "variou": [89, 91], "fp8": [90, 104, 107], "literatur": 90, "noun": 90, "et": 90, "al": 90, "2022": 90, "micikeviciu": 90, "e4": [90, 107], "e5": [90, 107], "analysi": [91, 92, 102, 103, 104, 107], "prune_non_float_tensor": [91, 104], "prune_same_scale_tensor": [91, 104], "tend": [91, 107], "procedur": 92, "recurs": [92, 94, 95, 97, 101, 102], "build": [92, 107], "fundament": 92, "proce": 92, "five": 92, "stage": 92, "identif": 92, "compar": 92, "guid": [92, 104], "unconstrain": 92, "proof": 92, "initialis": [92, 106, 107], "approach": [92, 94, 107], "compil": [92, 94, 95, 104, 105, 107], "own": [92, 94], "system": [92, 94], "easi": [92, 94], "interoper": [92, 94], "definit": 92, "basic": [92, 107], "told": 92, "explicitli": [92, 106], "substitut": 92, "new_gelu": 92, "test": [92, 105, 106, 107], "said": 92, "anticip": 92, "ultim": 92, "alon": 92, "prioriti": 92, "non_recurse_funct": 94, "graph_modul": 94, "backend_1": 94, "backend_2": 94, "_modul": 94, "torch_nn_modules_to_user_modul": [95, 104], "patch": 95, "target_fn": 96, "keep_type_expr": 96, "accompani": [96, 106], "retain": [96, 106], "mod": 97, "trivial_subclass": 97, "develop": [98, 107], "dataclass": 99, "scalepair": [100, 104], "ctx": 100, "functionctx": 100, "vjp": 100, "needs_input_grad": 100, "jvp": 100, "grad_input": 100, "got": 100, "mark_dirti": 100, "setup_context": 100, "matter": 100, "torch_doctest_autograd": 100, "staticmethod": 100, "x_npy": 100, "once_differenti": 100, "grad_output": 100, "lead": [100, 107], "wrong": 100, "mark_non_differenti": 100, "save_for_backward": 100, "g1": 100, "g2": 100, "saved_tensor": 100, "zeros_lik": 100, "oppos": 100, "leak": 100, "saved_tensors_hook": 100, "intermediari": 100, "neither": 100, "nor": 100, "recomput": 100, "tutori": [100, 107], "weren": 100, "grad_out": 100, "gx": 100, "gy": 100, "gz": 100, "save_for_forward": 100, "x_t": 100, "y_t": 100, "fwad": 100, "dual_level": 100, "a_dual": 100, "make_du": 100, "set_materialize_grad": 100, "materi": 100, "simplefunc": 100, "No": 100, "induc": 100, "insid": 100, "vmap": 100, "info": 100, "in_dim": 100, "underneath": 100, "generate_vmap_rul": 100, "choos": 100, "out_dim": 100, "instrument": 101, "boxed_run": 101, "args_list": 101, "interpret": 101, "promptli": 101, "call_funct": 101, "invoc": 101, "call_method": 101, "opoverload": 101, "call_modul": 101, "fetch_args_kwargs_from_env": 101, "fetch": 101, "environ": 101, "fetch_attr": 101, "hierarchi": 101, "qualifi": 101, "get_attr": 101, "Will": 101, "map_nodes_to_valu": 101, "belong": 101, "report": 101, "realli": 101, "referenc": 101, "placehold": 101, "tracer": 101, "target_to_funct": 101, "initial_env": 101, "enable_io_process": 101, "partial": 101, "process_input": 101, "process_output": 101, "run_nod": 101, "recurse_modul": 102, "syntax_highlight": 102, "autowrap_modul": 102, "einop": 102, "home": 102, "runner": 102, "lib": 102, "site": 102, "packag": 102, "py": 102, "autowrap_funct": 102, "dummi": 102, "union": 102, "fed": [102, 107], "plain": 102, "toggl": 102, "behavour": 102, "moduletyp": 102, "fc1": 102, "fc2": 102, "236": 102, "fc1_weight": 102, "018": [102, 107], "54": 102, "fc1_bia": 102, "0182": 102, "51": 102, "_c": [102, 107], "_nn": [102, 107], "578": [102, 107], "204": [102, 107], "337": 102, "288": 102, "fc2_weight": 102, "00902": [102, 107], "13": 102, "fc2_bia": 102, "00904": 102, "31": 102, "linear_1": [102, 107], "welcom": 104, "design": [104, 106, 107], "facilit": 104, "outlin": [104, 107], "icml": [104, 107], "notebook": [104, 106, 107], "nanogpt": 104, "git": [104, 107], "graphcor": [104, 106, 107], "research": [104, 107], "broad": 104, "overview": [104, 107], "occasion": [104, 107], "bug": [104, 105, 107], "keen": [104, 105, 107], "encount": [104, 107], "fork": [104, 107], "repo": [104, 107], "instruct": [104, 107], "consider": 104, "blog": 104, "almost": 104, "depthmodulelist": 104, "depthsequenti": 104, "linearreadout": 104, "mhsa": 104, "transformerdecod": 104, "transformerlay": 104, "graph_to_datafram": 104, "apply_constraint": 104, "format_to_tupl": 104, "tuple_to_format": 104, "cross_entropi": [104, 107], "linear_readout": 104, "mse_loss": 104, "residual_appli": 104, "rms_norm": 104, "silu_glu": 104, "lr_scale_for_depth": 104, "lr_scale_func_adam": 104, "scale_bwd": 104, "scale_fwd": 104, "prune_selected_nod": 104, "simulate_format": 104, "apply_transform": 104, "patch_to_expand_modul": 104, "replace_node_with_funct": 104, "analyse_modul": [104, 107], "scaletrack": 104, "scaletrackinginterpret": 104, "logarithmic_interpol": 104, "scale_elementwis": 104, "despit": 105, "best": [105, 107], "effort": 105, "free": 105, "assist": 105, "anyon": 105, "issu": [105, 107], "coverag": 105, "ve": [105, 106], "focuss": 105, "difficulti": 105, "although": [105, 107], "suspect": 105, "haven": 105, "exhaust": 105, "encourag": 105, "touch": 105, "tl": 106, "dr": 106, "good": [106, 107], "thing": 106, "roughli": [106, 107], "behaviour": [106, 107], "satur": 106, "stabl": 106, "prime": 106, "color": 106, "green": 106, "insuffici": 106, "red": 106, "ll": 106, "explain": 106, "dynam": [106, 107], "wors": 106, "condens": 106, "summari": 106, "sim": 106, "infti": 106, "flat": 106, "uncorrel": 106, "spike": 106, "assumpt": [106, 107], "companion": 106, "find": [106, 107], "propos": 106, "autoregress": 106, "languag": 106, "shakespear": 106, "saw": 106, "sweep": 106, "unfortun": 106, "tini": 106, "shakespar": 106, "bert": 106, "explan": 106, "intrigu": 106, "presum": 106, "accid": 106, "turn": 106, "solut": 106, "bad": 106, "reproduc": 106, "inde": 106, "care": [106, 107], "principl": 106, "underpin": 106, "far": [106, 107], "question": 106, "interest": 106, "With": 106, "thank": 106, "charli": 106, "blake": 106, "feedback": 106, "douglaso": 106, "ai": 106, "cover": 107, "brief": 107, "discuss": 107, "paradigm": 107, "aim": 107, "involv": 107, "scratch": 107, "advantag": 107, "great": 107, "headroom": 107, "grow": 107, "shrink": 107, "underflow": 107, "drift": 107, "fp16": 107, "decreas": 107, "treatment": 107, "motiv": 107, "sens": 107, "bf16": 107, "veri": 107, "larg": 107, "3e": 107, "38": 107, "45": 107, "60": 107, "000": 107, "6e": 107, "speed": 107, "scope": 107, "introduct": 107, "tricki": 107, "easier": 107, "breakdown": 107, "alongsid": 107, "unscaledmlp": 107, "linear_2": 107, "annotated_cod": 107, "linear_1_weight": 107, "83": 107, "linear_1_bia": 107, "84": 107, "322": 107, "289": 107, "linear_2_weight": 107, "48": 107, "linear_2_bia": 107, "00894": 107, "198": 107, "firstli": 107, "decompos": 107, "secondli": 107, "fwd_scale": 107, "bwd_scale": 107, "enough": 107, "unscal": 107, "scaledmlp": 107, "716": 107, "729": 107, "707": 107, "706": 107, "693": 107, "03": 107, "979": 107, "art": 107, "aris": 107, "clearli": 107, "explod": 107, "degrad": 107, "steadili": 107, "meet": 107, "concern": 107, "merit": 107, "investig": 107, "attain": 107, "substanti": 107, "push": 107, "themselv": 107, "solv": 107, "outsid": 107, "separ": 107, "residuallay": 107, "contrast": 107, "50": 107, "down": 107, "emploi": 107, "trick": 107, "comprehens": 107, "scenario": 107, "arriv": 107, "fan_in": 107, "fan_out": 107, "grad_weight_scal": 107, "grad_bias_scal": 107, "ideal": 107, "compromis": 107, "eager": 107, "trip": 107, "fortun": 107, "hi": 107, "overhead": 107, "fusion": 107, "answer": 107, "jit": 107, "script": 107, "rectifi": 107, "flexibl": 107, "unit_scaled_funct": 107, "unitscaledmodul": 107, "incur": 107, "naiv": 107, "benchmark": 107, "thorough": 107, "strongli": 107, "latest": 107, "recent": 107, "upgrad": 107, "preview": 107, "nightli": 107}, "objects": {"": [[3, 0, 0, "-", "unit_scaling"]], "unit_scaling": [[4, 1, 1, "", "CrossEntropyLoss"], [5, 1, 1, "", "DepthModuleList"], [6, 1, 1, "", "DepthSequential"], [7, 1, 1, "", "Dropout"], [8, 1, 1, "", "Embedding"], [9, 1, 1, "", "GELU"], [10, 1, 1, "", "LayerNorm"], [11, 1, 1, "", "Linear"], [12, 1, 1, "", "LinearReadout"], [13, 1, 1, "", "MHSA"], [14, 1, 1, "", "MLP"], [15, 4, 1, "", "Parameter"], [16, 1, 1, "", "RMSNorm"], [17, 1, 1, "", "SiLU"], [18, 1, 1, "", "Softmax"], [19, 1, 1, "", "TransformerDecoder"], [20, 1, 1, "", "TransformerLayer"], [21, 0, 0, "-", "analysis"], [26, 0, 0, "-", "constraints"], [35, 0, 0, "-", "core"], [41, 0, 0, "-", "formats"], [45, 0, 0, "-", "functional"], [64, 0, 0, "-", "optim"], [72, 0, 0, "-", "parameter"], [79, 0, 0, "-", "scale"], [82, 4, 1, "", "transformer_residual_scaling_rule"], [83, 0, 0, "-", "transforms"], [98, 0, 0, "-", "utils"], [103, 4, 1, "", "visualiser"]], "unit_scaling.DepthModuleList": [[5, 2, 1, "", "append"], [5, 2, 1, "", "extend"], [5, 2, 1, "", "insert"]], "unit_scaling.DepthSequential": [[6, 2, 1, "", "append"]], "unit_scaling.Embedding": [[8, 2, 1, "", "from_pretrained"], [8, 3, 1, "", "weight"]], "unit_scaling.LayerNorm": [[10, 3, 1, "", "bias"], [10, 3, 1, "", "weight"]], "unit_scaling.Linear": [[11, 3, 1, "", "bias"], [11, 3, 1, "", "weight"]], "unit_scaling.LinearReadout": [[12, 3, 1, "", "bias"], [12, 3, 1, "", "weight"]], "unit_scaling.RMSNorm": [[16, 3, 1, "", "weight"]], "unit_scaling.TransformerDecoder": [[19, 2, 1, "", "append"]], "unit_scaling.analysis": [[22, 4, 1, "", "example_batch"], [23, 4, 1, "", "graph_to_dataframe"], [24, 4, 1, "", "plot"], [25, 4, 1, "", "visualiser"]], "unit_scaling.constraints": [[27, 4, 1, "", "amean"], [28, 4, 1, "", "apply_constraint"], [29, 4, 1, "", "gmean"], [30, 4, 1, "", "hmean"], [31, 4, 1, "", "to_grad_input_scale"], [32, 4, 1, "", "to_left_grad_scale"], [33, 4, 1, "", "to_output_scale"], [34, 4, 1, "", "to_right_grad_scale"]], "unit_scaling.core": [[36, 0, 0, "-", "functional"]], "unit_scaling.core.functional": [[37, 4, 1, "", "logarithmic_interpolation"], [38, 4, 1, "", "rms"], [39, 4, 1, "", "scale_elementwise"], [40, 4, 1, "", "transformer_residual_scaling_rule"]], "unit_scaling.formats": [[42, 1, 1, "", "FPFormat"], [43, 4, 1, "", "format_to_tuple"], [44, 4, 1, "", "tuple_to_format"]], "unit_scaling.formats.FPFormat": [[42, 5, 1, "", "bits"], [42, 5, 1, "", "max_absolute_value"], [42, 5, 1, "", "min_absolute_normal"], [42, 5, 1, "", "min_absolute_subnormal"], [42, 2, 1, "", "quantise"], [42, 2, 1, "", "quantise_bwd"], [42, 2, 1, "", "quantise_fwd"]], "unit_scaling.functional": [[46, 4, 1, "", "add"], [47, 4, 1, "", "cross_entropy"], [48, 4, 1, "", "dropout"], [49, 4, 1, "", "embedding"], [50, 4, 1, "", "gelu"], [51, 4, 1, "", "layer_norm"], [52, 4, 1, "", "linear"], [53, 4, 1, "", "linear_readout"], [54, 4, 1, "", "matmul"], [55, 4, 1, "", "mse_loss"], [56, 4, 1, "", "residual_add"], [57, 4, 1, "", "residual_apply"], [58, 4, 1, "", "residual_split"], [59, 4, 1, "", "rms_norm"], [60, 4, 1, "", "scaled_dot_product_attention"], [61, 4, 1, "", "silu"], [62, 4, 1, "", "silu_glu"], [63, 4, 1, "", "softmax"]], "unit_scaling.optim": [[65, 1, 1, "", "Adam"], [66, 1, 1, "", "AdamW"], [67, 1, 1, "", "SGD"], [68, 4, 1, "", "lr_scale_for_depth"], [69, 4, 1, "", "lr_scale_func_adam"], [70, 4, 1, "", "lr_scale_func_sgd"], [71, 4, 1, "", "scaled_parameters"]], "unit_scaling.optim.Adam": [[65, 2, 1, "", "add_param_group"], [65, 2, 1, "", "load_state_dict"], [65, 2, 1, "", "register_load_state_dict_post_hook"], [65, 2, 1, "", "register_load_state_dict_pre_hook"], [65, 2, 1, "", "register_state_dict_post_hook"], [65, 2, 1, "", "register_state_dict_pre_hook"], [65, 2, 1, "", "register_step_post_hook"], [65, 2, 1, "", "register_step_pre_hook"], [65, 2, 1, "", "state_dict"], [65, 2, 1, "", "step"], [65, 2, 1, "", "zero_grad"]], "unit_scaling.optim.AdamW": [[66, 2, 1, "", "add_param_group"], [66, 2, 1, "", "load_state_dict"], [66, 2, 1, "", "register_load_state_dict_post_hook"], [66, 2, 1, "", "register_load_state_dict_pre_hook"], [66, 2, 1, "", "register_state_dict_post_hook"], [66, 2, 1, "", "register_state_dict_pre_hook"], [66, 2, 1, "", "register_step_post_hook"], [66, 2, 1, "", "register_step_pre_hook"], [66, 2, 1, "", "state_dict"], [66, 2, 1, "", "step"], [66, 2, 1, "", "zero_grad"]], "unit_scaling.optim.SGD": [[67, 2, 1, "", "add_param_group"], [67, 2, 1, "", "load_state_dict"], [67, 2, 1, "", "register_load_state_dict_post_hook"], [67, 2, 1, "", "register_load_state_dict_pre_hook"], [67, 2, 1, "", "register_state_dict_post_hook"], [67, 2, 1, "", "register_state_dict_pre_hook"], [67, 2, 1, "", "register_step_post_hook"], [67, 2, 1, "", "register_step_pre_hook"], [67, 2, 1, "", "state_dict"], [67, 2, 1, "", "step"], [67, 2, 1, "", "zero_grad"]], "unit_scaling.parameter": [[73, 1, 1, "", "OrderedDict"], [74, 4, 1, "", "Parameter"], [75, 1, 1, "", "ParameterData"], [76, 1, 1, "", "Protocol"], [77, 1, 1, "", "Tensor"], [78, 4, 1, "", "has_parameter_data"]], "unit_scaling.parameter.OrderedDict": [[73, 2, 1, "", "clear"], [73, 2, 1, "", "copy"], [73, 2, 1, "", "fromkeys"], [73, 2, 1, "", "get"], [73, 2, 1, "", "items"], [73, 2, 1, "", "keys"], [73, 2, 1, "", "move_to_end"], [73, 2, 1, "", "pop"], [73, 2, 1, "", "popitem"], [73, 2, 1, "", "setdefault"], [73, 2, 1, "", "update"], [73, 2, 1, "", "values"]], "unit_scaling.parameter.Tensor": [[77, 3, 1, "", "H"], [77, 3, 1, "", "T"], [77, 2, 1, "", "abs"], [77, 2, 1, "", "abs_"], [77, 2, 1, "", "absolute"], [77, 2, 1, "", "absolute_"], [77, 2, 1, "", "acos"], [77, 2, 1, "", "acos_"], [77, 2, 1, "", "acosh"], [77, 2, 1, "", "acosh_"], [77, 2, 1, "", "add"], [77, 2, 1, "", "add_"], [77, 2, 1, "", "addbmm"], [77, 2, 1, "", "addbmm_"], [77, 2, 1, "", "addcdiv"], [77, 2, 1, "", "addcdiv_"], [77, 2, 1, "", "addcmul"], [77, 2, 1, "", "addcmul_"], [77, 2, 1, "", "addmm"], [77, 2, 1, "", "addmm_"], [77, 2, 1, "", "addmv"], [77, 2, 1, "", "addmv_"], [77, 2, 1, "", "addr"], [77, 2, 1, "", "addr_"], [77, 2, 1, "", "adjoint"], [77, 2, 1, "", "align_as"], [77, 2, 1, "", "align_to"], [77, 2, 1, "", "all"], [77, 2, 1, "", "allclose"], [77, 2, 1, "", "amax"], [77, 2, 1, "", "amin"], [77, 2, 1, "", "aminmax"], [77, 2, 1, "", "angle"], [77, 2, 1, "", "any"], [77, 2, 1, "", "apply_"], [77, 2, 1, "", "arccos"], [77, 2, 1, "", "arccos_"], [77, 2, 1, "", "arccosh"], [77, 2, 1, "", "arccosh_"], [77, 2, 1, "", "arcsin"], [77, 2, 1, "", "arcsin_"], [77, 2, 1, "", "arcsinh"], [77, 2, 1, "", "arcsinh_"], [77, 2, 1, "", "arctan"], [77, 2, 1, "", "arctan2"], [77, 2, 1, "", "arctan2_"], [77, 2, 1, "", "arctan_"], [77, 2, 1, "", "arctanh"], [77, 2, 1, "", "arctanh_"], [77, 2, 1, "", "argmax"], [77, 2, 1, "", "argmin"], [77, 2, 1, "", "argsort"], [77, 2, 1, "", "argwhere"], [77, 2, 1, "", "as_strided"], [77, 2, 1, "", "as_strided_"], [77, 2, 1, "", "as_strided_scatter"], [77, 2, 1, "", "as_subclass"], [77, 2, 1, "", "asin"], [77, 2, 1, "", "asin_"], [77, 2, 1, "", "asinh"], [77, 2, 1, "", "asinh_"], [77, 2, 1, "", "atan"], [77, 2, 1, "", "atan2"], [77, 2, 1, "", "atan2_"], [77, 2, 1, "", "atan_"], [77, 2, 1, "", "atanh"], [77, 2, 1, "", "atanh_"], [77, 2, 1, "", "backward"], [77, 2, 1, "", "baddbmm"], [77, 2, 1, "", "baddbmm_"], [77, 2, 1, "", "bernoulli"], [77, 2, 1, "", "bernoulli_"], [77, 2, 1, "", "bfloat16"], [77, 2, 1, "", "bincount"], [77, 2, 1, "", "bitwise_and"], [77, 2, 1, "", "bitwise_and_"], [77, 2, 1, "", "bitwise_left_shift"], [77, 2, 1, "", "bitwise_left_shift_"], [77, 2, 1, "", "bitwise_not"], [77, 2, 1, "", "bitwise_not_"], [77, 2, 1, "", "bitwise_or"], [77, 2, 1, "", "bitwise_or_"], [77, 2, 1, "", "bitwise_right_shift"], [77, 2, 1, "", "bitwise_right_shift_"], [77, 2, 1, "", "bitwise_xor"], [77, 2, 1, "", "bitwise_xor_"], [77, 2, 1, "", "bmm"], [77, 2, 1, "", "bool"], [77, 2, 1, "", "broadcast_to"], [77, 2, 1, "", "byte"], [77, 2, 1, "", "cauchy_"], [77, 2, 1, "", "cdouble"], [77, 2, 1, "", "ceil"], [77, 2, 1, "", "ceil_"], [77, 2, 1, "", "cfloat"], [77, 2, 1, "", "chalf"], [77, 2, 1, "", "char"], [77, 2, 1, "", "cholesky"], [77, 2, 1, "", "cholesky_inverse"], [77, 2, 1, "", "cholesky_solve"], [77, 2, 1, "", "chunk"], [77, 2, 1, "", "clamp"], [77, 2, 1, "", "clamp_"], [77, 2, 1, "", "clip"], [77, 2, 1, "", "clip_"], [77, 2, 1, "", "clone"], [77, 2, 1, "", "coalesce"], [77, 2, 1, "", "col_indices"], [77, 2, 1, "", "conj"], [77, 2, 1, "", "conj_physical"], [77, 2, 1, "", "conj_physical_"], [77, 2, 1, "", "contiguous"], [77, 2, 1, "", "copy_"], [77, 2, 1, "", "copysign"], [77, 2, 1, "", "copysign_"], [77, 2, 1, "", "corrcoef"], [77, 2, 1, "", "cos"], [77, 2, 1, "", "cos_"], [77, 2, 1, "", "cosh"], [77, 2, 1, "", "cosh_"], [77, 2, 1, "", "count_nonzero"], [77, 2, 1, "", "cov"], [77, 2, 1, "", "cpu"], [77, 2, 1, "", "cross"], [77, 2, 1, "", "crow_indices"], [77, 2, 1, "", "cuda"], [77, 2, 1, "", "cummax"], [77, 2, 1, "", "cummin"], [77, 2, 1, "", "cumprod"], [77, 2, 1, "", "cumprod_"], [77, 2, 1, "", "cumsum"], [77, 2, 1, "", "cumsum_"], [77, 2, 1, "", "data_ptr"], [77, 2, 1, "", "deg2rad"], [77, 2, 1, "", "deg2rad_"], [77, 2, 1, "", "dense_dim"], [77, 2, 1, "", "dequantize"], [77, 2, 1, "", "det"], [77, 2, 1, "", "detach"], [77, 2, 1, "", "detach_"], [77, 3, 1, "", "device"], [77, 2, 1, "", "diag"], [77, 2, 1, "", "diag_embed"], [77, 2, 1, "", "diagflat"], [77, 2, 1, "", "diagonal"], [77, 2, 1, "", "diagonal_scatter"], [77, 2, 1, "", "diff"], [77, 2, 1, "", "digamma"], [77, 2, 1, "", "digamma_"], [77, 2, 1, "", "dim"], [77, 2, 1, "", "dim_order"], [77, 2, 1, "", "dist"], [77, 2, 1, "", "div"], [77, 2, 1, "", "div_"], [77, 2, 1, "", "divide"], [77, 2, 1, "", "divide_"], [77, 2, 1, "", "dot"], [77, 2, 1, "", "double"], [77, 2, 1, "", "dsplit"], [77, 2, 1, "", "element_size"], [77, 2, 1, "", "eq"], [77, 2, 1, "", "eq_"], [77, 2, 1, "", "equal"], [77, 2, 1, "", "erf"], [77, 2, 1, "", "erf_"], [77, 2, 1, "", "erfc"], [77, 2, 1, "", "erfc_"], [77, 2, 1, "", "erfinv"], [77, 2, 1, "", "erfinv_"], [77, 2, 1, "", "exp"], [77, 2, 1, "", "exp2"], [77, 2, 1, "", "exp2_"], [77, 2, 1, "", "exp_"], [77, 2, 1, "", "expand"], [77, 2, 1, "", "expand_as"], [77, 2, 1, "", "expm1"], [77, 2, 1, "", "expm1_"], [77, 2, 1, "", "exponential_"], [77, 2, 1, "", "fill_"], [77, 2, 1, "", "fill_diagonal_"], [77, 2, 1, "", "fix"], [77, 2, 1, "", "fix_"], [77, 2, 1, "", "flatten"], [77, 2, 1, "", "flip"], [77, 2, 1, "", "fliplr"], [77, 2, 1, "", "flipud"], [77, 2, 1, "", "float"], [77, 2, 1, "", "float_power"], [77, 2, 1, "", "float_power_"], [77, 2, 1, "", "floor"], [77, 2, 1, "", "floor_"], [77, 2, 1, "", "floor_divide"], [77, 2, 1, "", "floor_divide_"], [77, 2, 1, "", "fmax"], [77, 2, 1, "", "fmin"], [77, 2, 1, "", "fmod"], [77, 2, 1, "", "fmod_"], [77, 2, 1, "", "frac"], [77, 2, 1, "", "frac_"], [77, 2, 1, "", "frexp"], [77, 2, 1, "", "gather"], [77, 2, 1, "", "gcd"], [77, 2, 1, "", "gcd_"], [77, 2, 1, "", "ge"], [77, 2, 1, "", "ge_"], [77, 2, 1, "", "geometric_"], [77, 2, 1, "", "geqrf"], [77, 2, 1, "", "ger"], [77, 2, 1, "", "get_device"], [77, 3, 1, "", "grad"], [77, 2, 1, "", "greater"], [77, 2, 1, "", "greater_"], [77, 2, 1, "", "greater_equal"], [77, 2, 1, "", "greater_equal_"], [77, 2, 1, "", "gt"], [77, 2, 1, "", "gt_"], [77, 2, 1, "", "half"], [77, 2, 1, "", "hardshrink"], [77, 2, 1, "", "has_names"], [77, 2, 1, "", "heaviside"], [77, 2, 1, "", "heaviside_"], [77, 2, 1, "", "histc"], [77, 2, 1, "", "histogram"], [77, 2, 1, "", "hsplit"], [77, 2, 1, "", "hypot"], [77, 2, 1, "", "hypot_"], [77, 2, 1, "", "i0"], [77, 2, 1, "", "i0_"], [77, 2, 1, "", "igamma"], [77, 2, 1, "", "igamma_"], [77, 2, 1, "", "igammac"], [77, 2, 1, "", "igammac_"], [77, 3, 1, "", "imag"], [77, 2, 1, "", "index_add"], [77, 2, 1, "", "index_add_"], [77, 2, 1, "", "index_copy"], [77, 2, 1, "", "index_copy_"], [77, 2, 1, "", "index_fill"], [77, 2, 1, "", "index_fill_"], [77, 2, 1, "", "index_put"], [77, 2, 1, "", "index_put_"], [77, 2, 1, "", "index_reduce_"], [77, 2, 1, "", "index_select"], [77, 2, 1, "", "indices"], [77, 2, 1, "", "inner"], [77, 2, 1, "", "int"], [77, 2, 1, "", "int_repr"], [77, 2, 1, "", "inverse"], [77, 2, 1, "", "ipu"], [77, 2, 1, "", "is_coalesced"], [77, 2, 1, "", "is_complex"], [77, 2, 1, "", "is_conj"], [77, 2, 1, "", "is_contiguous"], [77, 3, 1, "", "is_cpu"], [77, 3, 1, "", "is_cuda"], [77, 2, 1, "", "is_floating_point"], [77, 2, 1, "", "is_inference"], [77, 3, 1, "", "is_ipu"], [77, 3, 1, "", "is_leaf"], [77, 3, 1, "", "is_meta"], [77, 3, 1, "", "is_mps"], [77, 2, 1, "", "is_neg"], [77, 2, 1, "", "is_pinned"], [77, 3, 1, "", "is_quantized"], [77, 2, 1, "", "is_set_to"], [77, 2, 1, "", "is_shared"], [77, 2, 1, "", "is_signed"], [77, 3, 1, "", "is_sparse"], [77, 3, 1, "", "is_sparse_csr"], [77, 3, 1, "", "is_xla"], [77, 3, 1, "", "is_xpu"], [77, 2, 1, "", "isclose"], [77, 2, 1, "", "isfinite"], [77, 2, 1, "", "isinf"], [77, 2, 1, "", "isnan"], [77, 2, 1, "", "isneginf"], [77, 2, 1, "", "isposinf"], [77, 2, 1, "", "isreal"], [77, 2, 1, "", "istft"], [77, 2, 1, "", "item"], [77, 3, 1, "", "itemsize"], [77, 2, 1, "", "kron"], [77, 2, 1, "", "kthvalue"], [77, 2, 1, "", "lcm"], [77, 2, 1, "", "lcm_"], [77, 2, 1, "", "ldexp"], [77, 2, 1, "", "ldexp_"], [77, 2, 1, "", "le"], [77, 2, 1, "", "le_"], [77, 2, 1, "", "lerp"], [77, 2, 1, "", "lerp_"], [77, 2, 1, "", "less"], [77, 2, 1, "", "less_"], [77, 2, 1, "", "less_equal"], [77, 2, 1, "", "less_equal_"], [77, 2, 1, "", "lgamma"], [77, 2, 1, "", "lgamma_"], [77, 2, 1, "", "log"], [77, 2, 1, "", "log10"], [77, 2, 1, "", "log10_"], [77, 2, 1, "", "log1p"], [77, 2, 1, "", "log1p_"], [77, 2, 1, "", "log2"], [77, 2, 1, "", "log2_"], [77, 2, 1, "", "log_"], [77, 2, 1, "", "log_normal_"], [77, 2, 1, "", "logaddexp"], [77, 2, 1, "", "logaddexp2"], [77, 2, 1, "", "logcumsumexp"], [77, 2, 1, "", "logdet"], [77, 2, 1, "", "logical_and"], [77, 2, 1, "", "logical_and_"], [77, 2, 1, "", "logical_not"], [77, 2, 1, "", "logical_not_"], [77, 2, 1, "", "logical_or"], [77, 2, 1, "", "logical_or_"], [77, 2, 1, "", "logical_xor"], [77, 2, 1, "", "logical_xor_"], [77, 2, 1, "", "logit"], [77, 2, 1, "", "logit_"], [77, 2, 1, "", "logsumexp"], [77, 2, 1, "", "long"], [77, 2, 1, "", "lt"], [77, 2, 1, "", "lt_"], [77, 2, 1, "", "lu"], [77, 2, 1, "", "lu_solve"], [77, 3, 1, "", "mH"], [77, 3, 1, "", "mT"], [77, 2, 1, "", "map_"], [77, 2, 1, "", "masked_fill"], [77, 2, 1, "", "masked_fill_"], [77, 2, 1, "", "masked_scatter"], [77, 2, 1, "", "masked_scatter_"], [77, 2, 1, "", "masked_select"], [77, 2, 1, "", "matmul"], [77, 2, 1, "", "matrix_exp"], [77, 2, 1, "", "matrix_power"], [77, 2, 1, "", "max"], [77, 2, 1, "", "maximum"], [77, 2, 1, "", "mean"], [77, 2, 1, "", "median"], [77, 2, 1, "", "min"], [77, 2, 1, "", "minimum"], [77, 2, 1, "", "mm"], [77, 2, 1, "", "mode"], [77, 2, 1, "", "module_load"], [77, 2, 1, "", "moveaxis"], [77, 2, 1, "", "movedim"], [77, 2, 1, "", "msort"], [77, 2, 1, "", "mul"], [77, 2, 1, "", "mul_"], [77, 2, 1, "", "multinomial"], [77, 2, 1, "", "multiply"], [77, 2, 1, "", "multiply_"], [77, 2, 1, "", "mv"], [77, 2, 1, "", "mvlgamma"], [77, 2, 1, "", "mvlgamma_"], [77, 3, 1, "", "names"], [77, 2, 1, "", "nan_to_num"], [77, 2, 1, "", "nan_to_num_"], [77, 2, 1, "", "nanmean"], [77, 2, 1, "", "nanmedian"], [77, 2, 1, "", "nanquantile"], [77, 2, 1, "", "nansum"], [77, 2, 1, "", "narrow"], [77, 2, 1, "", "narrow_copy"], [77, 3, 1, "", "nbytes"], [77, 3, 1, "", "ndim"], [77, 2, 1, "", "ndimension"], [77, 2, 1, "", "ne"], [77, 2, 1, "", "ne_"], [77, 2, 1, "", "neg"], [77, 2, 1, "", "neg_"], [77, 2, 1, "", "negative"], [77, 2, 1, "", "negative_"], [77, 2, 1, "", "nelement"], [77, 2, 1, "", "new_empty"], [77, 2, 1, "", "new_empty_strided"], [77, 2, 1, "", "new_full"], [77, 2, 1, "", "new_ones"], [77, 2, 1, "", "new_tensor"], [77, 2, 1, "", "new_zeros"], [77, 2, 1, "", "nextafter"], [77, 2, 1, "", "nextafter_"], [77, 2, 1, "", "nonzero"], [77, 2, 1, "", "nonzero_static"], [77, 2, 1, "", "norm"], [77, 2, 1, "", "normal_"], [77, 2, 1, "", "not_equal"], [77, 2, 1, "", "not_equal_"], [77, 2, 1, "", "numel"], [77, 2, 1, "", "numpy"], [77, 2, 1, "", "orgqr"], [77, 2, 1, "", "ormqr"], [77, 2, 1, "", "outer"], [77, 2, 1, "", "permute"], [77, 2, 1, "", "pin_memory"], [77, 2, 1, "", "pinverse"], [77, 2, 1, "", "polygamma"], [77, 2, 1, "", "polygamma_"], [77, 2, 1, "", "positive"], [77, 2, 1, "", "pow"], [77, 2, 1, "", "pow_"], [77, 2, 1, "", "prod"], [77, 2, 1, "", "put"], [77, 2, 1, "", "put_"], [77, 2, 1, "", "q_per_channel_axis"], [77, 2, 1, "", "q_per_channel_scales"], [77, 2, 1, "", "q_per_channel_zero_points"], [77, 2, 1, "", "q_scale"], [77, 2, 1, "", "q_zero_point"], [77, 2, 1, "", "qr"], [77, 2, 1, "", "qscheme"], [77, 2, 1, "", "quantile"], [77, 2, 1, "", "rad2deg"], [77, 2, 1, "", "rad2deg_"], [77, 2, 1, "", "random_"], [77, 2, 1, "", "ravel"], [77, 3, 1, "", "real"], [77, 2, 1, "", "reciprocal"], [77, 2, 1, "", "reciprocal_"], [77, 2, 1, "", "record_stream"], [77, 2, 1, "", "refine_names"], [77, 2, 1, "", "register_hook"], [77, 2, 1, "", "register_post_accumulate_grad_hook"], [77, 2, 1, "", "remainder"], [77, 2, 1, "", "remainder_"], [77, 2, 1, "", "rename"], [77, 2, 1, "", "rename_"], [77, 2, 1, "", "renorm"], [77, 2, 1, "", "renorm_"], [77, 2, 1, "", "repeat"], [77, 2, 1, "", "repeat_interleave"], [77, 3, 1, "", "requires_grad"], [77, 2, 1, "", "requires_grad_"], [77, 2, 1, "", "reshape"], [77, 2, 1, "", "reshape_as"], [77, 2, 1, "", "resize_"], [77, 2, 1, "", "resize_as_"], [77, 2, 1, "", "resolve_conj"], [77, 2, 1, "", "resolve_neg"], [77, 2, 1, "", "retain_grad"], [77, 3, 1, "", "retains_grad"], [77, 2, 1, "", "roll"], [77, 2, 1, "", "rot90"], [77, 2, 1, "", "round"], [77, 2, 1, "", "round_"], [77, 2, 1, "", "rsqrt"], [77, 2, 1, "", "rsqrt_"], [77, 2, 1, "", "scatter"], [77, 2, 1, "", "scatter_"], [77, 2, 1, "", "scatter_add"], [77, 2, 1, "", "scatter_add_"], [77, 2, 1, "", "scatter_reduce"], [77, 2, 1, "", "scatter_reduce_"], [77, 2, 1, "", "select"], [77, 2, 1, "", "select_scatter"], [77, 2, 1, "", "set_"], [77, 2, 1, "", "sgn"], [77, 2, 1, "", "sgn_"], [77, 3, 1, "", "shape"], [77, 2, 1, "", "share_memory_"], [77, 2, 1, "", "short"], [77, 2, 1, "", "sigmoid"], [77, 2, 1, "", "sigmoid_"], [77, 2, 1, "", "sign"], [77, 2, 1, "", "sign_"], [77, 2, 1, "", "signbit"], [77, 2, 1, "", "sin"], [77, 2, 1, "", "sin_"], [77, 2, 1, "", "sinc"], [77, 2, 1, "", "sinc_"], [77, 2, 1, "", "sinh"], [77, 2, 1, "", "sinh_"], [77, 2, 1, "", "size"], [77, 2, 1, "", "slice_scatter"], [77, 2, 1, "", "slogdet"], [77, 2, 1, "", "smm"], [77, 2, 1, "", "softmax"], [77, 2, 1, "", "sort"], [77, 2, 1, "", "sparse_dim"], [77, 2, 1, "", "sparse_mask"], [77, 2, 1, "", "sparse_resize_"], [77, 2, 1, "", "sparse_resize_and_clear_"], [77, 2, 1, "", "split"], [77, 2, 1, "", "sqrt"], [77, 2, 1, "", "sqrt_"], [77, 2, 1, "", "square"], [77, 2, 1, "", "square_"], [77, 2, 1, "", "squeeze"], [77, 2, 1, "", "squeeze_"], [77, 2, 1, "", "sspaddmm"], [77, 2, 1, "", "std"], [77, 2, 1, "", "stft"], [77, 2, 1, "", "storage"], [77, 2, 1, "", "storage_offset"], [77, 2, 1, "", "storage_type"], [77, 2, 1, "", "stride"], [77, 2, 1, "", "sub"], [77, 2, 1, "", "sub_"], [77, 2, 1, "", "subtract"], [77, 2, 1, "", "subtract_"], [77, 2, 1, "", "sum"], [77, 2, 1, "", "sum_to_size"], [77, 2, 1, "", "svd"], [77, 2, 1, "", "swapaxes"], [77, 2, 1, "", "swapaxes_"], [77, 2, 1, "", "swapdims"], [77, 2, 1, "", "swapdims_"], [77, 2, 1, "", "t"], [77, 2, 1, "", "t_"], [77, 2, 1, "", "take"], [77, 2, 1, "", "take_along_dim"], [77, 2, 1, "", "tan"], [77, 2, 1, "", "tan_"], [77, 2, 1, "", "tanh"], [77, 2, 1, "", "tanh_"], [77, 2, 1, "", "tensor_split"], [77, 2, 1, "", "tile"], [77, 2, 1, "", "to"], [77, 2, 1, "", "to_dense"], [77, 2, 1, "", "to_mkldnn"], [77, 2, 1, "", "to_padded_tensor"], [77, 2, 1, "", "to_sparse"], [77, 2, 1, "", "to_sparse_bsc"], [77, 2, 1, "", "to_sparse_bsr"], [77, 2, 1, "", "to_sparse_coo"], [77, 2, 1, "", "to_sparse_csc"], [77, 2, 1, "", "to_sparse_csr"], [77, 2, 1, "", "tolist"], [77, 2, 1, "", "topk"], [77, 2, 1, "", "trace"], [77, 2, 1, "", "transpose"], [77, 2, 1, "", "transpose_"], [77, 2, 1, "", "triangular_solve"], [77, 2, 1, "", "tril"], [77, 2, 1, "", "tril_"], [77, 2, 1, "", "triu"], [77, 2, 1, "", "triu_"], [77, 2, 1, "", "true_divide"], [77, 2, 1, "", "true_divide_"], [77, 2, 1, "", "trunc"], [77, 2, 1, "", "trunc_"], [77, 2, 1, "", "type"], [77, 2, 1, "", "type_as"], [77, 2, 1, "", "unbind"], [77, 2, 1, "", "unflatten"], [77, 2, 1, "", "unfold"], [77, 2, 1, "", "uniform_"], [77, 2, 1, "", "unique"], [77, 2, 1, "", "unique_consecutive"], [77, 2, 1, "", "unsafe_chunk"], [77, 2, 1, "", "unsafe_split"], [77, 2, 1, "", "unsqueeze"], [77, 2, 1, "", "unsqueeze_"], [77, 2, 1, "", "untyped_storage"], [77, 2, 1, "", "values"], [77, 2, 1, "", "var"], [77, 2, 1, "", "vdot"], [77, 2, 1, "", "view"], [77, 2, 1, "", "view_as"], [77, 2, 1, "", "vsplit"], [77, 2, 1, "", "where"], [77, 2, 1, "", "xlogy"], [77, 2, 1, "", "xlogy_"], [77, 2, 1, "", "xpu"], [77, 2, 1, "", "zero_"]], "unit_scaling.scale": [[80, 4, 1, "", "scale_bwd"], [81, 4, 1, "", "scale_fwd"]], "unit_scaling.transforms": [[84, 1, 1, "", "Metrics"], [85, 4, 1, "", "compile"], [86, 4, 1, "", "prune_non_float_tensors"], [87, 4, 1, "", "prune_same_scale_tensors"], [88, 4, 1, "", "prune_selected_nodes"], [89, 4, 1, "", "simulate_format"], [90, 4, 1, "", "simulate_fp8"], [91, 4, 1, "", "track_scales"], [92, 4, 1, "", "unit_scale"], [93, 0, 0, "-", "utils"]], "unit_scaling.transforms.Metrics": [[84, 1, 1, "", "Data"]], "unit_scaling.transforms.utils": [[94, 4, 1, "", "apply_transform"], [95, 4, 1, "", "patch_to_expand_modules"], [96, 4, 1, "", "replace_node_with_function"], [97, 4, 1, "", "torch_nn_modules_to_user_modules"]], "unit_scaling.utils": [[99, 1, 1, "", "ScalePair"], [100, 1, 1, "", "ScaleTracker"], [101, 1, 1, "", "ScaleTrackingInterpreter"], [102, 4, 1, "", "analyse_module"]], "unit_scaling.utils.ScaleTracker": [[100, 2, 1, "", "backward"], [100, 2, 1, "", "jvp"], [100, 2, 1, "", "mark_dirty"], [100, 2, 1, "", "mark_non_differentiable"], [100, 2, 1, "", "save_for_backward"], [100, 2, 1, "", "save_for_forward"], [100, 2, 1, "", "set_materialize_grads"], [100, 2, 1, "", "setup_context"], [100, 2, 1, "", "vjp"], [100, 2, 1, "", "vmap"]], "unit_scaling.utils.ScaleTrackingInterpreter": [[101, 2, 1, "", "boxed_run"], [101, 2, 1, "", "call_function"], [101, 2, 1, "", "call_method"], [101, 2, 1, "", "call_module"], [101, 2, 1, "", "fetch_args_kwargs_from_env"], [101, 2, 1, "", "fetch_attr"], [101, 2, 1, "", "get_attr"], [101, 2, 1, "", "map_nodes_to_values"], [101, 2, 1, "", "output"], [101, 2, 1, "", "placeholder"], [101, 2, 1, "", "run"], [101, 2, 1, "", "run_node"]]}, "objtypes": {"0": "py:module", "1": "py:class", "2": "py:method", "3": "py:attribute", "4": "py:function", "5": "py:property"}, "objnames": {"0": ["py", "module", "Python module"], "1": ["py", "class", "Python class"], "2": ["py", "method", "Python method"], "3": ["py", "attribute", "Python attribute"], "4": ["py", "function", "Python function"], "5": ["py", "property", "Python property"]}, "titleterms": {"api": 0, "refer": 0, "unit": [1, 104, 107], "scale": [1, 79, 80, 81, 104, 106, 107], "blog": 1, "almost": [1, 106], "dot": [1, 106], "product": [1, 106], "self": 1, "attent": [1, 106], "develop": [2, 104], "unit_sc": [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103], "crossentropyloss": 4, "depthmodulelist": 5, "depthsequenti": 6, "dropout": [7, 48], "embed": [8, 49], "gelu": [9, 50], "layernorm": 10, "linear": [11, 52], "linearreadout": 12, "mhsa": 13, "mlp": 14, "paramet": [15, 72, 73, 74, 75, 76, 77, 78], "rmsnorm": 16, "silu": [17, 61], "softmax": [18, 63], "transformerdecod": 19, "transformerlay": 20, "analysi": [21, 22, 23, 24, 25], "example_batch": 22, "graph_to_datafram": 23, "plot": 24, "visualis": [25, 103], "constraint": [26, 27, 28, 29, 30, 31, 32, 33, 34], "amean": 27, "apply_constraint": 28, "gmean": 29, "hmean": 30, "to_grad_input_scal": 31, "to_left_grad_scal": 32, "to_output_scal": 33, "to_right_grad_scal": 34, "core": [35, 36, 37, 38, 39, 40], "function": [36, 37, 38, 39, 40, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], "logarithmic_interpol": 37, "rm": 38, "scale_elementwis": 39, "transformer_residual_scaling_rul": [40, 82], "format": [41, 42, 43, 44], "fpformat": 42, "format_to_tupl": 43, "tuple_to_format": 44, "add": 46, "cross_entropi": 47, "layer_norm": 51, "linear_readout": 53, "matmul": 54, "mse_loss": 55, "residual_add": 56, "residual_appli": 57, "residual_split": 58, "rms_norm": 59, "scaled_dot_product_attent": 60, "silu_glu": 62, "optim": [64, 65, 66, 67, 68, 69, 70, 71], "adam": 65, "adamw": 66, "sgd": 67, "lr_scale_for_depth": 68, "lr_scale_func_adam": 69, "lr_scale_func_sgd": 70, "scaled_paramet": 71, "ordereddict": 73, "parameterdata": 75, "protocol": 76, "tensor": 77, "has_parameter_data": 78, "scale_bwd": 80, "scale_fwd": 81, "transform": [83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97], "metric": 84, "compil": 85, "prune_non_float_tensor": 86, "prune_same_scale_tensor": 87, "prune_selected_nod": 88, "simulate_format": 89, "simulate_fp8": 90, "track_scal": 91, "unit_scal": 92, "util": [93, 94, 95, 96, 97, 98, 99, 100, 101, 102], "apply_transform": 94, "patch_to_expand_modul": 95, "replace_node_with_funct": 96, "torch_nn_modules_to_user_modul": 97, "scalepair": 99, "scaletrack": 100, "scaletrackinginterpret": 101, "analyse_modul": 102, "instal": [104, 107], "get": 104, "start": 104, "content": 104, "limit": 105, "where": 106, "doe": 106, "d_": 106, "seq": 106, "e": 106, "1": 106, "2": 106, "come": 106, "from": 106, "work": 106, "No": 106, "conclus": 106, "user": 107, "guid": 107, "what": 107, "i": 107, "how": 107, "model": 107, "kei": 107, "consider": 107, "optimis": 107}, "envversion": {"sphinx.domains.c": 2, "sphinx.domains.changeset": 1, "sphinx.domains.citation": 1, "sphinx.domains.cpp": 8, "sphinx.domains.index": 1, "sphinx.domains.javascript": 2, "sphinx.domains.math": 2, "sphinx.domains.python": 3, "sphinx.domains.rst": 2, "sphinx.domains.std": 2, "sphinx.ext.intersphinx": 1, "sphinx.ext.viewcode": 1, "sphinx": 57}, "alltitles": {"API reference": [[0, "api-reference"]], "Unit Scaling blog": [[1, "unit-scaling-blog"]], "Almost scaled dot-product self attention": [[1, "almost-scaled-dot-product-self-attention"]], "Development": [[2, "development"], [104, "development"]], "unit_scaling": [[3, "module-unit_scaling"]], "unit_scaling.CrossEntropyLoss": [[4, "unit-scaling-crossentropyloss"]], "unit_scaling.DepthModuleList": [[5, "unit-scaling-depthmodulelist"]], "unit_scaling.DepthSequential": [[6, "unit-scaling-depthsequential"]], "unit_scaling.Dropout": [[7, "unit-scaling-dropout"]], "unit_scaling.Embedding": [[8, "unit-scaling-embedding"]], "unit_scaling.GELU": [[9, "unit-scaling-gelu"]], "unit_scaling.LayerNorm": [[10, "unit-scaling-layernorm"]], "unit_scaling.Linear": [[11, "unit-scaling-linear"]], "unit_scaling.LinearReadout": [[12, "unit-scaling-linearreadout"]], "unit_scaling.MHSA": [[13, "unit-scaling-mhsa"]], "unit_scaling.MLP": [[14, "unit-scaling-mlp"]], "unit_scaling.Parameter": [[15, "unit-scaling-parameter"]], "unit_scaling.RMSNorm": [[16, "unit-scaling-rmsnorm"]], "unit_scaling.SiLU": [[17, "unit-scaling-silu"]], "unit_scaling.Softmax": [[18, "unit-scaling-softmax"]], "unit_scaling.TransformerDecoder": [[19, "unit-scaling-transformerdecoder"]], "unit_scaling.TransformerLayer": [[20, "unit-scaling-transformerlayer"]], "unit_scaling.analysis": [[21, "module-unit_scaling.analysis"]], "unit_scaling.analysis.example_batch": [[22, "unit-scaling-analysis-example-batch"]], "unit_scaling.analysis.graph_to_dataframe": [[23, "unit-scaling-analysis-graph-to-dataframe"]], "unit_scaling.analysis.plot": [[24, "unit-scaling-analysis-plot"]], "unit_scaling.analysis.visualiser": [[25, "unit-scaling-analysis-visualiser"]], "unit_scaling.constraints": [[26, "module-unit_scaling.constraints"]], "unit_scaling.constraints.amean": [[27, "unit-scaling-constraints-amean"]], "unit_scaling.constraints.apply_constraint": [[28, "unit-scaling-constraints-apply-constraint"]], "unit_scaling.constraints.gmean": [[29, "unit-scaling-constraints-gmean"]], "unit_scaling.constraints.hmean": [[30, "unit-scaling-constraints-hmean"]], "unit_scaling.constraints.to_grad_input_scale": [[31, "unit-scaling-constraints-to-grad-input-scale"]], "unit_scaling.constraints.to_left_grad_scale": [[32, "unit-scaling-constraints-to-left-grad-scale"]], "unit_scaling.constraints.to_output_scale": [[33, "unit-scaling-constraints-to-output-scale"]], "unit_scaling.constraints.to_right_grad_scale": [[34, "unit-scaling-constraints-to-right-grad-scale"]], "unit_scaling.core": [[35, "module-unit_scaling.core"]], "unit_scaling.core.functional": [[36, "module-unit_scaling.core.functional"]], "unit_scaling.core.functional.logarithmic_interpolation": [[37, "unit-scaling-core-functional-logarithmic-interpolation"]], "unit_scaling.core.functional.rms": [[38, "unit-scaling-core-functional-rms"]], "unit_scaling.core.functional.scale_elementwise": [[39, "unit-scaling-core-functional-scale-elementwise"]], "unit_scaling.core.functional.transformer_residual_scaling_rule": [[40, "unit-scaling-core-functional-transformer-residual-scaling-rule"]], "unit_scaling.formats": [[41, "module-unit_scaling.formats"]], "unit_scaling.formats.FPFormat": [[42, "unit-scaling-formats-fpformat"]], "unit_scaling.formats.format_to_tuple": [[43, "unit-scaling-formats-format-to-tuple"]], "unit_scaling.formats.tuple_to_format": [[44, "unit-scaling-formats-tuple-to-format"]], "unit_scaling.functional": [[45, "module-unit_scaling.functional"]], "unit_scaling.functional.add": [[46, "unit-scaling-functional-add"]], "unit_scaling.functional.cross_entropy": [[47, "unit-scaling-functional-cross-entropy"]], "unit_scaling.functional.dropout": [[48, "unit-scaling-functional-dropout"]], "unit_scaling.functional.embedding": [[49, "unit-scaling-functional-embedding"]], "unit_scaling.functional.gelu": [[50, "unit-scaling-functional-gelu"]], "unit_scaling.functional.layer_norm": [[51, "unit-scaling-functional-layer-norm"]], "unit_scaling.functional.linear": [[52, "unit-scaling-functional-linear"]], "unit_scaling.functional.linear_readout": [[53, "unit-scaling-functional-linear-readout"]], "unit_scaling.functional.matmul": [[54, "unit-scaling-functional-matmul"]], "unit_scaling.functional.mse_loss": [[55, "unit-scaling-functional-mse-loss"]], "unit_scaling.functional.residual_add": [[56, "unit-scaling-functional-residual-add"]], "unit_scaling.functional.residual_apply": [[57, "unit-scaling-functional-residual-apply"]], "unit_scaling.functional.residual_split": [[58, "unit-scaling-functional-residual-split"]], "unit_scaling.functional.rms_norm": [[59, "unit-scaling-functional-rms-norm"]], "unit_scaling.functional.scaled_dot_product_attention": [[60, "unit-scaling-functional-scaled-dot-product-attention"]], "unit_scaling.functional.silu": [[61, "unit-scaling-functional-silu"]], "unit_scaling.functional.silu_glu": [[62, "unit-scaling-functional-silu-glu"]], "unit_scaling.functional.softmax": [[63, "unit-scaling-functional-softmax"]], "unit_scaling.optim": [[64, "module-unit_scaling.optim"]], "unit_scaling.optim.Adam": [[65, "unit-scaling-optim-adam"]], "unit_scaling.optim.AdamW": [[66, "unit-scaling-optim-adamw"]], "unit_scaling.optim.SGD": [[67, "unit-scaling-optim-sgd"]], "unit_scaling.optim.lr_scale_for_depth": [[68, "unit-scaling-optim-lr-scale-for-depth"]], "unit_scaling.optim.lr_scale_func_adam": [[69, "unit-scaling-optim-lr-scale-func-adam"]], "unit_scaling.optim.lr_scale_func_sgd": [[70, "unit-scaling-optim-lr-scale-func-sgd"]], "unit_scaling.optim.scaled_parameters": [[71, "unit-scaling-optim-scaled-parameters"]], "unit_scaling.parameter": [[72, "module-unit_scaling.parameter"]], "unit_scaling.parameter.OrderedDict": [[73, "unit-scaling-parameter-ordereddict"]], "unit_scaling.parameter.Parameter": [[74, "unit-scaling-parameter-parameter"]], "unit_scaling.parameter.ParameterData": [[75, "unit-scaling-parameter-parameterdata"]], "unit_scaling.parameter.Protocol": [[76, "unit-scaling-parameter-protocol"]], "unit_scaling.parameter.Tensor": [[77, "unit-scaling-parameter-tensor"]], "unit_scaling.parameter.has_parameter_data": [[78, "unit-scaling-parameter-has-parameter-data"]], "unit_scaling.scale": [[79, "module-unit_scaling.scale"]], "unit_scaling.scale.scale_bwd": [[80, "unit-scaling-scale-scale-bwd"]], "unit_scaling.scale.scale_fwd": [[81, "unit-scaling-scale-scale-fwd"]], "unit_scaling.transformer_residual_scaling_rule": [[82, "unit-scaling-transformer-residual-scaling-rule"]], "unit_scaling.transforms": [[83, "module-unit_scaling.transforms"]], "unit_scaling.transforms.Metrics": [[84, "unit-scaling-transforms-metrics"]], "unit_scaling.transforms.compile": [[85, "unit-scaling-transforms-compile"]], "unit_scaling.transforms.prune_non_float_tensors": [[86, "unit-scaling-transforms-prune-non-float-tensors"]], "unit_scaling.transforms.prune_same_scale_tensors": [[87, "unit-scaling-transforms-prune-same-scale-tensors"]], "unit_scaling.transforms.prune_selected_nodes": [[88, "unit-scaling-transforms-prune-selected-nodes"]], "unit_scaling.transforms.simulate_format": [[89, "unit-scaling-transforms-simulate-format"]], "unit_scaling.transforms.simulate_fp8": [[90, "unit-scaling-transforms-simulate-fp8"]], "unit_scaling.transforms.track_scales": [[91, "unit-scaling-transforms-track-scales"]], "unit_scaling.transforms.unit_scale": [[92, "unit-scaling-transforms-unit-scale"]], "unit_scaling.transforms.utils": [[93, "module-unit_scaling.transforms.utils"]], "unit_scaling.transforms.utils.apply_transform": [[94, "unit-scaling-transforms-utils-apply-transform"]], "unit_scaling.transforms.utils.patch_to_expand_modules": [[95, "unit-scaling-transforms-utils-patch-to-expand-modules"]], "unit_scaling.transforms.utils.replace_node_with_function": [[96, "unit-scaling-transforms-utils-replace-node-with-function"]], "unit_scaling.transforms.utils.torch_nn_modules_to_user_modules": [[97, "unit-scaling-transforms-utils-torch-nn-modules-to-user-modules"]], "unit_scaling.utils": [[98, "module-unit_scaling.utils"]], "unit_scaling.utils.ScalePair": [[99, "unit-scaling-utils-scalepair"]], "unit_scaling.utils.ScaleTracker": [[100, "unit-scaling-utils-scaletracker"]], "unit_scaling.utils.ScaleTrackingInterpreter": [[101, "unit-scaling-utils-scaletrackinginterpreter"]], "unit_scaling.utils.analyse_module": [[102, "unit-scaling-utils-analyse-module"]], "unit_scaling.visualiser": [[103, "unit-scaling-visualiser"]], "Unit Scaling": [[104, "unit-scaling"]], "Installation": [[104, "installation"], [107, "installation"]], "Getting Started": [[104, "getting-started"]], "Contents": [[104, null]], "Limitations": [[105, "limitations"]], "Almost-scaled dot-product attention": [[106, "almost-scaled-dot-product-attention"]], "Where does (d_{seq}/e)^{1/2} come from?": [[106, "where-does-d-seq-e-1-2-come-from"]], "Does it work? \u2026No!": [[106, "does-it-work-no"]], "Conclusion": [[106, "conclusion"]], "User guide": [[107, "user-guide"]], "What is unit scaling?": [[107, "what-is-unit-scaling"]], "How to unit-scale a model": [[107, "how-to-unit-scale-a-model"]], "Key considerations for unit scaling": [[107, "key-considerations-for-unit-scaling"]], "Optimising unit-scaled models": [[107, "optimising-unit-scaled-models"]]}, "indexentries": {"module": [[3, "module-unit_scaling"], [21, "module-unit_scaling.analysis"], [26, "module-unit_scaling.constraints"], [35, "module-unit_scaling.core"], [36, "module-unit_scaling.core.functional"], [41, "module-unit_scaling.formats"], [45, "module-unit_scaling.functional"], [64, "module-unit_scaling.optim"], [72, "module-unit_scaling.parameter"], [79, "module-unit_scaling.scale"], [83, "module-unit_scaling.transforms"], [93, "module-unit_scaling.transforms.utils"], [98, "module-unit_scaling.utils"]], "unit_scaling": [[3, "module-unit_scaling"]], "crossentropyloss (class in unit_scaling)": [[4, "unit_scaling.CrossEntropyLoss"]], "depthmodulelist (class in unit_scaling)": [[5, "unit_scaling.DepthModuleList"]], "append() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.append"]], "extend() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.extend"]], "insert() (unit_scaling.depthmodulelist method)": [[5, "unit_scaling.DepthModuleList.insert"]], "depthsequential (class in unit_scaling)": [[6, "unit_scaling.DepthSequential"]], "append() (unit_scaling.depthsequential method)": [[6, "unit_scaling.DepthSequential.append"]], "dropout (class in unit_scaling)": [[7, "unit_scaling.Dropout"]], "embedding (class in unit_scaling)": [[8, "unit_scaling.Embedding"]], "from_pretrained() (unit_scaling.embedding class method)": [[8, "unit_scaling.Embedding.from_pretrained"]], "weight (unit_scaling.embedding attribute)": [[8, "unit_scaling.Embedding.weight"]], "gelu (class in unit_scaling)": [[9, "unit_scaling.GELU"]], "layernorm (class in unit_scaling)": [[10, "unit_scaling.LayerNorm"]], "bias (unit_scaling.layernorm attribute)": [[10, "unit_scaling.LayerNorm.bias"]], "weight (unit_scaling.layernorm attribute)": [[10, "unit_scaling.LayerNorm.weight"]], "linear (class in unit_scaling)": [[11, "unit_scaling.Linear"]], "bias (unit_scaling.linear attribute)": [[11, "unit_scaling.Linear.bias"]], "weight (unit_scaling.linear attribute)": [[11, "unit_scaling.Linear.weight"]], "linearreadout (class in unit_scaling)": [[12, "unit_scaling.LinearReadout"]], "bias (unit_scaling.linearreadout attribute)": [[12, "unit_scaling.LinearReadout.bias"]], "weight (unit_scaling.linearreadout attribute)": [[12, "unit_scaling.LinearReadout.weight"]], "mhsa (class in unit_scaling)": [[13, "unit_scaling.MHSA"]], "mlp (class in unit_scaling)": [[14, "unit_scaling.MLP"]], "parameter() (in module unit_scaling)": [[15, "unit_scaling.Parameter"]], "rmsnorm (class in unit_scaling)": [[16, "unit_scaling.RMSNorm"]], "weight (unit_scaling.rmsnorm attribute)": [[16, "unit_scaling.RMSNorm.weight"]], "silu (class in unit_scaling)": [[17, "unit_scaling.SiLU"]], "softmax (class in unit_scaling)": [[18, "unit_scaling.Softmax"]], "transformerdecoder (class in unit_scaling)": [[19, "unit_scaling.TransformerDecoder"]], "append() (unit_scaling.transformerdecoder method)": [[19, "unit_scaling.TransformerDecoder.append"]], "transformerlayer (class in unit_scaling)": [[20, "unit_scaling.TransformerLayer"]], "unit_scaling.analysis": [[21, "module-unit_scaling.analysis"]], "example_batch() (in module unit_scaling.analysis)": [[22, "unit_scaling.analysis.example_batch"]], "graph_to_dataframe() (in module unit_scaling.analysis)": [[23, "unit_scaling.analysis.graph_to_dataframe"]], "plot() (in module unit_scaling.analysis)": [[24, "unit_scaling.analysis.plot"]], "visualiser() (in module unit_scaling.analysis)": [[25, "unit_scaling.analysis.visualiser"]], "unit_scaling.constraints": [[26, "module-unit_scaling.constraints"]], "amean() (in module unit_scaling.constraints)": [[27, "unit_scaling.constraints.amean"]], "apply_constraint() (in module unit_scaling.constraints)": [[28, "unit_scaling.constraints.apply_constraint"]], "gmean() (in module unit_scaling.constraints)": [[29, "unit_scaling.constraints.gmean"]], "hmean() (in module unit_scaling.constraints)": [[30, "unit_scaling.constraints.hmean"]], "to_grad_input_scale() (in module unit_scaling.constraints)": [[31, "unit_scaling.constraints.to_grad_input_scale"]], "to_left_grad_scale() (in module unit_scaling.constraints)": [[32, "unit_scaling.constraints.to_left_grad_scale"]], "to_output_scale() (in module unit_scaling.constraints)": [[33, "unit_scaling.constraints.to_output_scale"]], "to_right_grad_scale() (in module unit_scaling.constraints)": [[34, "unit_scaling.constraints.to_right_grad_scale"]], "unit_scaling.core": [[35, "module-unit_scaling.core"]], "unit_scaling.core.functional": [[36, "module-unit_scaling.core.functional"]], "logarithmic_interpolation() (in module unit_scaling.core.functional)": [[37, "unit_scaling.core.functional.logarithmic_interpolation"]], "rms() (in module unit_scaling.core.functional)": [[38, "unit_scaling.core.functional.rms"]], "scale_elementwise() (in module unit_scaling.core.functional)": [[39, "unit_scaling.core.functional.scale_elementwise"]], "transformer_residual_scaling_rule() (in module unit_scaling.core.functional)": [[40, "unit_scaling.core.functional.transformer_residual_scaling_rule"]], "unit_scaling.formats": [[41, "module-unit_scaling.formats"]], "fpformat (class in unit_scaling.formats)": [[42, "unit_scaling.formats.FPFormat"]], "bits (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.bits"]], "max_absolute_value (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.max_absolute_value"]], "min_absolute_normal (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.min_absolute_normal"]], "min_absolute_subnormal (unit_scaling.formats.fpformat property)": [[42, "unit_scaling.formats.FPFormat.min_absolute_subnormal"]], "quantise() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise"]], "quantise_bwd() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise_bwd"]], "quantise_fwd() (unit_scaling.formats.fpformat method)": [[42, "unit_scaling.formats.FPFormat.quantise_fwd"]], "format_to_tuple() (in module unit_scaling.formats)": [[43, "unit_scaling.formats.format_to_tuple"]], "tuple_to_format() (in module unit_scaling.formats)": [[44, "unit_scaling.formats.tuple_to_format"]], "unit_scaling.functional": [[45, "module-unit_scaling.functional"]], "add() (in module unit_scaling.functional)": [[46, "unit_scaling.functional.add"]], "cross_entropy() (in module unit_scaling.functional)": [[47, "unit_scaling.functional.cross_entropy"]], "dropout() (in module unit_scaling.functional)": [[48, "unit_scaling.functional.dropout"]], "embedding() (in module unit_scaling.functional)": [[49, "unit_scaling.functional.embedding"]], "gelu() (in module unit_scaling.functional)": [[50, "unit_scaling.functional.gelu"]], "layer_norm() (in module unit_scaling.functional)": [[51, "unit_scaling.functional.layer_norm"]], "linear() (in module unit_scaling.functional)": [[52, "unit_scaling.functional.linear"]], "linear_readout() (in module unit_scaling.functional)": [[53, "unit_scaling.functional.linear_readout"]], "matmul() (in module unit_scaling.functional)": [[54, "unit_scaling.functional.matmul"]], "mse_loss() (in module unit_scaling.functional)": [[55, "unit_scaling.functional.mse_loss"]], "residual_add() (in module unit_scaling.functional)": [[56, "unit_scaling.functional.residual_add"]], "residual_apply() (in module unit_scaling.functional)": [[57, "unit_scaling.functional.residual_apply"]], "residual_split() (in module unit_scaling.functional)": [[58, "unit_scaling.functional.residual_split"]], "rms_norm() (in module unit_scaling.functional)": [[59, "unit_scaling.functional.rms_norm"]], "scaled_dot_product_attention() (in module unit_scaling.functional)": [[60, "unit_scaling.functional.scaled_dot_product_attention"]], "silu() (in module unit_scaling.functional)": [[61, "unit_scaling.functional.silu"]], "silu_glu() (in module unit_scaling.functional)": [[62, "unit_scaling.functional.silu_glu"]], "softmax() (in module unit_scaling.functional)": [[63, "unit_scaling.functional.softmax"]], "unit_scaling.optim": [[64, "module-unit_scaling.optim"]], "adam (class in unit_scaling.optim)": [[65, "unit_scaling.optim.Adam"]], "add_param_group() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.add_param_group"]], "load_state_dict() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.state_dict"]], "step() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.step"]], "zero_grad() (unit_scaling.optim.adam method)": [[65, "unit_scaling.optim.Adam.zero_grad"]], "adamw (class in unit_scaling.optim)": [[66, "unit_scaling.optim.AdamW"]], "add_param_group() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.add_param_group"]], "load_state_dict() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.state_dict"]], "step() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.step"]], "zero_grad() (unit_scaling.optim.adamw method)": [[66, "unit_scaling.optim.AdamW.zero_grad"]], "sgd (class in unit_scaling.optim)": [[67, "unit_scaling.optim.SGD"]], "add_param_group() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.add_param_group"]], "load_state_dict() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.load_state_dict"]], "register_load_state_dict_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_load_state_dict_post_hook"]], "register_load_state_dict_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_load_state_dict_pre_hook"]], "register_state_dict_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_state_dict_post_hook"]], "register_state_dict_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_state_dict_pre_hook"]], "register_step_post_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_step_post_hook"]], "register_step_pre_hook() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.register_step_pre_hook"]], "state_dict() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.state_dict"]], "step() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.step"]], "zero_grad() (unit_scaling.optim.sgd method)": [[67, "unit_scaling.optim.SGD.zero_grad"]], "lr_scale_for_depth() (in module unit_scaling.optim)": [[68, "unit_scaling.optim.lr_scale_for_depth"]], "lr_scale_func_adam() (in module unit_scaling.optim)": [[69, "unit_scaling.optim.lr_scale_func_adam"]], "lr_scale_func_sgd() (in module unit_scaling.optim)": [[70, "unit_scaling.optim.lr_scale_func_sgd"]], "scaled_parameters() (in module unit_scaling.optim)": [[71, "unit_scaling.optim.scaled_parameters"]], "unit_scaling.parameter": [[72, "module-unit_scaling.parameter"]], "ordereddict (class in unit_scaling.parameter)": [[73, "unit_scaling.parameter.OrderedDict"]], "clear() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.clear"]], "copy() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.copy"]], "fromkeys() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.fromkeys"]], "get() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.get"]], "items() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.items"]], "keys() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.keys"]], "move_to_end() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.move_to_end"]], "pop() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.pop"]], "popitem() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.popitem"]], "setdefault() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.setdefault"]], "update() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.update"]], "values() (unit_scaling.parameter.ordereddict method)": [[73, "unit_scaling.parameter.OrderedDict.values"]], "parameter() (in module unit_scaling.parameter)": [[74, "unit_scaling.parameter.Parameter"]], "parameterdata (class in unit_scaling.parameter)": [[75, "unit_scaling.parameter.ParameterData"]], "protocol (class in unit_scaling.parameter)": [[76, "unit_scaling.parameter.Protocol"]], "h (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.H"]], "t (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.T"]], "tensor (class in unit_scaling.parameter)": [[77, "unit_scaling.parameter.Tensor"]], "abs() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.abs"]], "abs_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.abs_"]], "absolute() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.absolute"]], "absolute_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.absolute_"]], "acos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acos"]], "acos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acos_"]], "acosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acosh"]], "acosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.acosh_"]], "add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.add"]], "add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.add_"]], "addbmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addbmm"]], "addbmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addbmm_"]], "addcdiv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcdiv"]], "addcdiv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcdiv_"]], "addcmul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcmul"]], "addcmul_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addcmul_"]], "addmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmm"]], "addmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmm_"]], "addmv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmv"]], "addmv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addmv_"]], "addr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addr"]], "addr_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.addr_"]], "adjoint() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.adjoint"]], "align_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.align_as"]], "align_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.align_to"]], "all() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.all"]], "allclose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.allclose"]], "amax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.amax"]], "amin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.amin"]], "aminmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.aminmax"]], "angle() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.angle"]], "any() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.any"]], "apply_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.apply_"]], "arccos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccos"]], "arccos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccos_"]], "arccosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccosh"]], "arccosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arccosh_"]], "arcsin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsin"]], "arcsin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsin_"]], "arcsinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsinh"]], "arcsinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arcsinh_"]], "arctan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan"]], "arctan2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan2"]], "arctan2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan2_"]], "arctan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctan_"]], "arctanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctanh"]], "arctanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.arctanh_"]], "argmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argmax"]], "argmin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argmin"]], "argsort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argsort"]], "argwhere() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.argwhere"]], "as_strided() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided"]], "as_strided_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided_"]], "as_strided_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_strided_scatter"]], "as_subclass() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.as_subclass"]], "asin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asin"]], "asin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asin_"]], "asinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asinh"]], "asinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.asinh_"]], "atan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan"]], "atan2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan2"]], "atan2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan2_"]], "atan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atan_"]], "atanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atanh"]], "atanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.atanh_"]], "backward() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.backward"]], "baddbmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.baddbmm"]], "baddbmm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.baddbmm_"]], "bernoulli() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bernoulli"]], "bernoulli_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bernoulli_"]], "bfloat16() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bfloat16"]], "bincount() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bincount"]], "bitwise_and() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_and"]], "bitwise_and_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_and_"]], "bitwise_left_shift() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_left_shift"]], "bitwise_left_shift_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_left_shift_"]], "bitwise_not() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_not"]], "bitwise_not_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_not_"]], "bitwise_or() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_or"]], "bitwise_or_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_or_"]], "bitwise_right_shift() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_right_shift"]], "bitwise_right_shift_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_right_shift_"]], "bitwise_xor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_xor"]], "bitwise_xor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bitwise_xor_"]], "bmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bmm"]], "bool() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.bool"]], "broadcast_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.broadcast_to"]], "byte() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.byte"]], "cauchy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cauchy_"]], "cdouble() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cdouble"]], "ceil() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ceil"]], "ceil_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ceil_"]], "cfloat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cfloat"]], "chalf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.chalf"]], "char() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.char"]], "cholesky() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky"]], "cholesky_inverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky_inverse"]], "cholesky_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cholesky_solve"]], "chunk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.chunk"]], "clamp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clamp"]], "clamp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clamp_"]], "clip() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clip"]], "clip_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clip_"]], "clone() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.clone"]], "coalesce() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.coalesce"]], "col_indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.col_indices"]], "conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj"]], "conj_physical() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj_physical"]], "conj_physical_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.conj_physical_"]], "contiguous() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.contiguous"]], "copy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copy_"]], "copysign() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copysign"]], "copysign_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.copysign_"]], "corrcoef() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.corrcoef"]], "cos() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cos"]], "cos_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cos_"]], "cosh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cosh"]], "cosh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cosh_"]], "count_nonzero() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.count_nonzero"]], "cov() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cov"]], "cpu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cpu"]], "cross() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cross"]], "crow_indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.crow_indices"]], "cuda() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cuda"]], "cummax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cummax"]], "cummin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cummin"]], "cumprod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumprod"]], "cumprod_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumprod_"]], "cumsum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumsum"]], "cumsum_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.cumsum_"]], "data_ptr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.data_ptr"]], "deg2rad() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.deg2rad"]], "deg2rad_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.deg2rad_"]], "dense_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dense_dim"]], "dequantize() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dequantize"]], "det() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.det"]], "detach() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.detach"]], "detach_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.detach_"]], "device (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.device"]], "diag() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diag"]], "diag_embed() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diag_embed"]], "diagflat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagflat"]], "diagonal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagonal"]], "diagonal_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diagonal_scatter"]], "diff() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.diff"]], "digamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.digamma"]], "digamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.digamma_"]], "dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dim"]], "dim_order() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dim_order"]], "dist() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dist"]], "div() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.div"]], "div_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.div_"]], "divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.divide"]], "divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.divide_"]], "dot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dot"]], "double() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.double"]], "dsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.dsplit"]], "element_size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.element_size"]], "eq() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.eq"]], "eq_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.eq_"]], "equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.equal"]], "erf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erf"]], "erf_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erf_"]], "erfc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfc"]], "erfc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfc_"]], "erfinv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfinv"]], "erfinv_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.erfinv_"]], "exp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp"]], "exp2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp2"]], "exp2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp2_"]], "exp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exp_"]], "expand() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expand"]], "expand_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expand_as"]], "expm1() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expm1"]], "expm1_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.expm1_"]], "exponential_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.exponential_"]], "fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fill_"]], "fill_diagonal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fill_diagonal_"]], "fix() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fix"]], "fix_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fix_"]], "flatten() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flatten"]], "flip() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flip"]], "fliplr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fliplr"]], "flipud() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.flipud"]], "float() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float"]], "float_power() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float_power"]], "float_power_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.float_power_"]], "floor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor"]], "floor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_"]], "floor_divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_divide"]], "floor_divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.floor_divide_"]], "fmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmax"]], "fmin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmin"]], "fmod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmod"]], "fmod_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.fmod_"]], "frac() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frac"]], "frac_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frac_"]], "frexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.frexp"]], "gather() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gather"]], "gcd() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gcd"]], "gcd_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gcd_"]], "ge() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ge"]], "ge_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ge_"]], "geometric_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.geometric_"]], "geqrf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.geqrf"]], "ger() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ger"]], "get_device() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.get_device"]], "grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.grad"]], "greater() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater"]], "greater_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_"]], "greater_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_equal"]], "greater_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.greater_equal_"]], "gt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gt"]], "gt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.gt_"]], "half() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.half"]], "hardshrink() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hardshrink"]], "has_names() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.has_names"]], "heaviside() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.heaviside"]], "heaviside_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.heaviside_"]], "histc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.histc"]], "histogram() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.histogram"]], "hsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hsplit"]], "hypot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hypot"]], "hypot_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.hypot_"]], "i0() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.i0"]], "i0_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.i0_"]], "igamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igamma"]], "igamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igamma_"]], "igammac() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igammac"]], "igammac_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.igammac_"]], "imag (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.imag"]], "index_add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_add"]], "index_add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_add_"]], "index_copy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_copy"]], "index_copy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_copy_"]], "index_fill() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_fill"]], "index_fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_fill_"]], "index_put() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_put"]], "index_put_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_put_"]], "index_reduce_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_reduce_"]], "index_select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.index_select"]], "indices() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.indices"]], "inner() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.inner"]], "int() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.int"]], "int_repr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.int_repr"]], "inverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.inverse"]], "ipu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ipu"]], "is_coalesced() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_coalesced"]], "is_complex() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_complex"]], "is_conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_conj"]], "is_contiguous() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_contiguous"]], "is_cpu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_cpu"]], "is_cuda (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_cuda"]], "is_floating_point() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_floating_point"]], "is_inference() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_inference"]], "is_ipu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_ipu"]], "is_leaf (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_leaf"]], "is_meta (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_meta"]], "is_mps (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_mps"]], "is_neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_neg"]], "is_pinned() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_pinned"]], "is_quantized (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_quantized"]], "is_set_to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_set_to"]], "is_shared() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_shared"]], "is_signed() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.is_signed"]], "is_sparse (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_sparse"]], "is_sparse_csr (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_sparse_csr"]], "is_xla (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_xla"]], "is_xpu (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.is_xpu"]], "isclose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isclose"]], "isfinite() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isfinite"]], "isinf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isinf"]], "isnan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isnan"]], "isneginf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isneginf"]], "isposinf() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isposinf"]], "isreal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.isreal"]], "istft() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.istft"]], "item() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.item"]], "itemsize (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.itemsize"]], "kron() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.kron"]], "kthvalue() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.kthvalue"]], "lcm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lcm"]], "lcm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lcm_"]], "ldexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ldexp"]], "ldexp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ldexp_"]], "le() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.le"]], "le_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.le_"]], "lerp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lerp"]], "lerp_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lerp_"]], "less() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less"]], "less_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_"]], "less_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_equal"]], "less_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.less_equal_"]], "lgamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lgamma"]], "lgamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lgamma_"]], "log() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log"]], "log10() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log10"]], "log10_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log10_"]], "log1p() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log1p"]], "log1p_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log1p_"]], "log2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log2"]], "log2_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log2_"]], "log_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log_"]], "log_normal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.log_normal_"]], "logaddexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logaddexp"]], "logaddexp2() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logaddexp2"]], "logcumsumexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logcumsumexp"]], "logdet() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logdet"]], "logical_and() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_and"]], "logical_and_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_and_"]], "logical_not() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_not"]], "logical_not_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_not_"]], "logical_or() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_or"]], "logical_or_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_or_"]], "logical_xor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_xor"]], "logical_xor_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logical_xor_"]], "logit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logit"]], "logit_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logit_"]], "logsumexp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.logsumexp"]], "long() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.long"]], "lt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lt"]], "lt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lt_"]], "lu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lu"]], "lu_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.lu_solve"]], "mh (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.mH"]], "mt (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.mT"]], "map_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.map_"]], "masked_fill() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_fill"]], "masked_fill_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_fill_"]], "masked_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_scatter"]], "masked_scatter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_scatter_"]], "masked_select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.masked_select"]], "matmul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matmul"]], "matrix_exp() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matrix_exp"]], "matrix_power() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.matrix_power"]], "max() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.max"]], "maximum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.maximum"]], "mean() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mean"]], "median() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.median"]], "min() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.min"]], "minimum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.minimum"]], "mm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mm"]], "mode() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mode"]], "module_load() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.module_load"]], "moveaxis() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.moveaxis"]], "movedim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.movedim"]], "msort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.msort"]], "mul() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mul"]], "mul_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mul_"]], "multinomial() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multinomial"]], "multiply() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multiply"]], "multiply_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.multiply_"]], "mv() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mv"]], "mvlgamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mvlgamma"]], "mvlgamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.mvlgamma_"]], "names (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.names"]], "nan_to_num() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nan_to_num"]], "nan_to_num_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nan_to_num_"]], "nanmean() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanmean"]], "nanmedian() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanmedian"]], "nanquantile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nanquantile"]], "nansum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nansum"]], "narrow() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.narrow"]], "narrow_copy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.narrow_copy"]], "nbytes (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.nbytes"]], "ndim (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.ndim"]], "ndimension() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ndimension"]], "ne() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ne"]], "ne_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ne_"]], "neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.neg"]], "neg_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.neg_"]], "negative() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.negative"]], "negative_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.negative_"]], "nelement() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nelement"]], "new_empty() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_empty"]], "new_empty_strided() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_empty_strided"]], "new_full() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_full"]], "new_ones() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_ones"]], "new_tensor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_tensor"]], "new_zeros() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.new_zeros"]], "nextafter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nextafter"]], "nextafter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nextafter_"]], "nonzero() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nonzero"]], "nonzero_static() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.nonzero_static"]], "norm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.norm"]], "normal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.normal_"]], "not_equal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.not_equal"]], "not_equal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.not_equal_"]], "numel() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.numel"]], "numpy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.numpy"]], "orgqr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.orgqr"]], "ormqr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ormqr"]], "outer() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.outer"]], "permute() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.permute"]], "pin_memory() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pin_memory"]], "pinverse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pinverse"]], "polygamma() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.polygamma"]], "polygamma_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.polygamma_"]], "positive() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.positive"]], "pow() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pow"]], "pow_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.pow_"]], "prod() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.prod"]], "put() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.put"]], "put_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.put_"]], "q_per_channel_axis() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_axis"]], "q_per_channel_scales() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_scales"]], "q_per_channel_zero_points() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_per_channel_zero_points"]], "q_scale() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_scale"]], "q_zero_point() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.q_zero_point"]], "qr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.qr"]], "qscheme() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.qscheme"]], "quantile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.quantile"]], "rad2deg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rad2deg"]], "rad2deg_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rad2deg_"]], "random_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.random_"]], "ravel() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.ravel"]], "real (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.real"]], "reciprocal() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reciprocal"]], "reciprocal_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reciprocal_"]], "record_stream() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.record_stream"]], "refine_names() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.refine_names"]], "register_hook() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.register_hook"]], "register_post_accumulate_grad_hook() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.register_post_accumulate_grad_hook"]], "remainder() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.remainder"]], "remainder_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.remainder_"]], "rename() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rename"]], "rename_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rename_"]], "renorm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.renorm"]], "renorm_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.renorm_"]], "repeat() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.repeat"]], "repeat_interleave() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.repeat_interleave"]], "requires_grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.requires_grad"]], "requires_grad_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.requires_grad_"]], "reshape() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reshape"]], "reshape_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.reshape_as"]], "resize_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resize_"]], "resize_as_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resize_as_"]], "resolve_conj() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resolve_conj"]], "resolve_neg() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.resolve_neg"]], "retain_grad() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.retain_grad"]], "retains_grad (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.retains_grad"]], "roll() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.roll"]], "rot90() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rot90"]], "round() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.round"]], "round_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.round_"]], "rsqrt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rsqrt"]], "rsqrt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.rsqrt_"]], "scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter"]], "scatter_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_"]], "scatter_add() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_add"]], "scatter_add_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_add_"]], "scatter_reduce() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_reduce"]], "scatter_reduce_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.scatter_reduce_"]], "select() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.select"]], "select_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.select_scatter"]], "set_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.set_"]], "sgn() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sgn"]], "sgn_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sgn_"]], "shape (unit_scaling.parameter.tensor attribute)": [[77, "unit_scaling.parameter.Tensor.shape"]], "share_memory_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.share_memory_"]], "short() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.short"]], "sigmoid() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sigmoid"]], "sigmoid_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sigmoid_"]], "sign() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sign"]], "sign_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sign_"]], "signbit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.signbit"]], "sin() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sin"]], "sin_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sin_"]], "sinc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinc"]], "sinc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinc_"]], "sinh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinh"]], "sinh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sinh_"]], "size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.size"]], "slice_scatter() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.slice_scatter"]], "slogdet() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.slogdet"]], "smm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.smm"]], "softmax() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.softmax"]], "sort() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sort"]], "sparse_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_dim"]], "sparse_mask() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_mask"]], "sparse_resize_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_resize_"]], "sparse_resize_and_clear_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sparse_resize_and_clear_"]], "split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.split"]], "sqrt() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sqrt"]], "sqrt_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sqrt_"]], "square() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.square"]], "square_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.square_"]], "squeeze() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.squeeze"]], "squeeze_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.squeeze_"]], "sspaddmm() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sspaddmm"]], "std() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.std"]], "stft() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.stft"]], "storage() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage"]], "storage_offset() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage_offset"]], "storage_type() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.storage_type"]], "stride() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.stride"]], "sub() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sub"]], "sub_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sub_"]], "subtract() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.subtract"]], "subtract_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.subtract_"]], "sum() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sum"]], "sum_to_size() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.sum_to_size"]], "svd() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.svd"]], "swapaxes() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapaxes"]], "swapaxes_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapaxes_"]], "swapdims() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapdims"]], "swapdims_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.swapdims_"]], "t() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.t"]], "t_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.t_"]], "take() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.take"]], "take_along_dim() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.take_along_dim"]], "tan() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tan"]], "tan_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tan_"]], "tanh() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tanh"]], "tanh_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tanh_"]], "tensor_split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tensor_split"]], "tile() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tile"]], "to() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to"]], "to_dense() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_dense"]], "to_mkldnn() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_mkldnn"]], "to_padded_tensor() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_padded_tensor"]], "to_sparse() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse"]], "to_sparse_bsc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_bsc"]], "to_sparse_bsr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_bsr"]], "to_sparse_coo() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_coo"]], "to_sparse_csc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_csc"]], "to_sparse_csr() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.to_sparse_csr"]], "tolist() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tolist"]], "topk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.topk"]], "trace() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trace"]], "transpose() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.transpose"]], "transpose_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.transpose_"]], "triangular_solve() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triangular_solve"]], "tril() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tril"]], "tril_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.tril_"]], "triu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triu"]], "triu_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.triu_"]], "true_divide() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.true_divide"]], "true_divide_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.true_divide_"]], "trunc() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trunc"]], "trunc_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.trunc_"]], "type() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.type"]], "type_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.type_as"]], "unbind() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unbind"]], "unflatten() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unflatten"]], "unfold() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unfold"]], "uniform_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.uniform_"]], "unique() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unique"]], "unique_consecutive() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unique_consecutive"]], "unsafe_chunk() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsafe_chunk"]], "unsafe_split() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsafe_split"]], "unsqueeze() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsqueeze"]], "unsqueeze_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.unsqueeze_"]], "untyped_storage() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.untyped_storage"]], "values() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.values"]], "var() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.var"]], "vdot() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.vdot"]], "view() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.view"]], "view_as() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.view_as"]], "vsplit() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.vsplit"]], "where() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.where"]], "xlogy() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xlogy"]], "xlogy_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xlogy_"]], "xpu() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.xpu"]], "zero_() (unit_scaling.parameter.tensor method)": [[77, "unit_scaling.parameter.Tensor.zero_"]], "has_parameter_data() (in module unit_scaling.parameter)": [[78, "unit_scaling.parameter.has_parameter_data"]], "unit_scaling.scale": [[79, "module-unit_scaling.scale"]], "scale_bwd() (in module unit_scaling.scale)": [[80, "unit_scaling.scale.scale_bwd"]], "scale_fwd() (in module unit_scaling.scale)": [[81, "unit_scaling.scale.scale_fwd"]], "transformer_residual_scaling_rule() (in module unit_scaling)": [[82, "unit_scaling.transformer_residual_scaling_rule"]], "unit_scaling.transforms": [[83, "module-unit_scaling.transforms"]], "metrics (class in unit_scaling.transforms)": [[84, "unit_scaling.transforms.Metrics"]], "metrics.data (class in unit_scaling.transforms)": [[84, "unit_scaling.transforms.Metrics.Data"]], "compile() (in module unit_scaling.transforms)": [[85, "unit_scaling.transforms.compile"]], "prune_non_float_tensors() (in module unit_scaling.transforms)": [[86, "unit_scaling.transforms.prune_non_float_tensors"]], "prune_same_scale_tensors() (in module unit_scaling.transforms)": [[87, "unit_scaling.transforms.prune_same_scale_tensors"]], "prune_selected_nodes() (in module unit_scaling.transforms)": [[88, "unit_scaling.transforms.prune_selected_nodes"]], "simulate_format() (in module unit_scaling.transforms)": [[89, "unit_scaling.transforms.simulate_format"]], "simulate_fp8() (in module unit_scaling.transforms)": [[90, "unit_scaling.transforms.simulate_fp8"]], "track_scales() (in module unit_scaling.transforms)": [[91, "unit_scaling.transforms.track_scales"]], "unit_scale() (in module unit_scaling.transforms)": [[92, "unit_scaling.transforms.unit_scale"]], "unit_scaling.transforms.utils": [[93, "module-unit_scaling.transforms.utils"]], "apply_transform() (in module unit_scaling.transforms.utils)": [[94, "unit_scaling.transforms.utils.apply_transform"]], "patch_to_expand_modules() (in module unit_scaling.transforms.utils)": [[95, "unit_scaling.transforms.utils.patch_to_expand_modules"]], "replace_node_with_function() (in module unit_scaling.transforms.utils)": [[96, "unit_scaling.transforms.utils.replace_node_with_function"]], "torch_nn_modules_to_user_modules() (in module unit_scaling.transforms.utils)": [[97, "unit_scaling.transforms.utils.torch_nn_modules_to_user_modules"]], "unit_scaling.utils": [[98, "module-unit_scaling.utils"]], "scalepair (class in unit_scaling.utils)": [[99, "unit_scaling.utils.ScalePair"]], "scaletracker (class in unit_scaling.utils)": [[100, "unit_scaling.utils.ScaleTracker"]], "backward() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.backward"]], "jvp() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.jvp"]], "mark_dirty() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.mark_dirty"]], "mark_non_differentiable() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.mark_non_differentiable"]], "save_for_backward() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.save_for_backward"]], "save_for_forward() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.save_for_forward"]], "set_materialize_grads() (unit_scaling.utils.scaletracker method)": [[100, "unit_scaling.utils.ScaleTracker.set_materialize_grads"]], "setup_context() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.setup_context"]], "vjp() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.vjp"]], "vmap() (unit_scaling.utils.scaletracker static method)": [[100, "unit_scaling.utils.ScaleTracker.vmap"]], "scaletrackinginterpreter (class in unit_scaling.utils)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter"]], "boxed_run() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.boxed_run"]], "call_function() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_function"]], "call_method() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_method"]], "call_module() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.call_module"]], "fetch_args_kwargs_from_env() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.fetch_args_kwargs_from_env"]], "fetch_attr() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.fetch_attr"]], "get_attr() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.get_attr"]], "map_nodes_to_values() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.map_nodes_to_values"]], "output() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.output"]], "placeholder() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.placeholder"]], "run() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.run"]], "run_node() (unit_scaling.utils.scaletrackinginterpreter method)": [[101, "unit_scaling.utils.ScaleTrackingInterpreter.run_node"]], "analyse_module() (in module unit_scaling.utils)": [[102, "unit_scaling.utils.analyse_module"]], "visualiser() (in module unit_scaling)": [[103, "unit_scaling.visualiser"]]}})
\ No newline at end of file
diff --git a/user_guide.html b/user_guide.html
index 197217e..31353fa 100644
--- a/user_guide.html
+++ b/user_guide.html
@@ -18,7 +18,7 @@
-
+
@@ -433,7 +433,7 @@ 1.5. Optimising unit-scaled models