diff --git a/MaxText/convert_gemma2_chkpt.py b/MaxText/convert_gemma2_chkpt.py index 33ddf73e2..a89746694 100644 --- a/MaxText/convert_gemma2_chkpt.py +++ b/MaxText/convert_gemma2_chkpt.py @@ -33,7 +33,7 @@ import orbax import checkpointing -from train import save_checkpoint +from MaxText.train import save_checkpoint Params = dict[str, Any] diff --git a/MaxText/convert_gemma_chkpt.py b/MaxText/convert_gemma_chkpt.py index 38881ac43..34f95fc2c 100644 --- a/MaxText/convert_gemma_chkpt.py +++ b/MaxText/convert_gemma_chkpt.py @@ -33,7 +33,7 @@ import orbax import checkpointing -from train import save_checkpoint +from MaxText.train import save_checkpoint Params = dict[str, Any] diff --git a/MaxText/convert_gpt3_ckpt_from_paxml.py b/MaxText/convert_gpt3_ckpt_from_paxml.py index 0f6d6111c..d9985a3e7 100644 --- a/MaxText/convert_gpt3_ckpt_from_paxml.py +++ b/MaxText/convert_gpt3_ckpt_from_paxml.py @@ -50,7 +50,7 @@ import gc import max_logging from psutil import Process -from train import save_checkpoint +from MaxText.train import save_checkpoint import argparse diff --git a/MaxText/generate_param_only_checkpoint.py b/MaxText/generate_param_only_checkpoint.py index a8c8fddfc..6e62a00b5 100644 --- a/MaxText/generate_param_only_checkpoint.py +++ b/MaxText/generate_param_only_checkpoint.py @@ -36,7 +36,7 @@ from jax import random from typing import Sequence from layers import models, quantizations -from train import save_checkpoint +from MaxText.train import save_checkpoint Transformer = models.Transformer diff --git a/MaxText/inference_microbenchmark.py b/MaxText/inference_microbenchmark.py index 3254faf05..a21305cce 100644 --- a/MaxText/inference_microbenchmark.py +++ b/MaxText/inference_microbenchmark.py @@ -25,7 +25,7 @@ from jetstream.engine import token_utils -import max_utils +from MaxText import max_utils import maxengine import maxtext_utils import profiler diff --git a/MaxText/input_pipeline/_tfds_data_processing.py b/MaxText/input_pipeline/_tfds_data_processing.py index 279765b3c..c7cc644ae 100644 --- a/MaxText/input_pipeline/_tfds_data_processing.py +++ b/MaxText/input_pipeline/_tfds_data_processing.py @@ -24,10 +24,10 @@ import tensorflow_datasets as tfds import jax -import multihost_dataloading -import tokenizer -import sequence_packing -from input_pipeline import _input_pipeline_utils +from MaxText import multihost_dataloading +from MaxText import tokenizer +from MaxText import sequence_packing +from MaxText.input_pipeline import _input_pipeline_utils AUTOTUNE = tf.data.experimental.AUTOTUNE diff --git a/MaxText/kernels/ragged_attention.py b/MaxText/kernels/ragged_attention.py index 8ddeb7214..381f416dd 100644 --- a/MaxText/kernels/ragged_attention.py +++ b/MaxText/kernels/ragged_attention.py @@ -24,7 +24,7 @@ from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np -import common_types +from MaxText import common_types from jax.experimental import shard_map diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 784bd5f28..b07fb73be 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -27,13 +27,13 @@ from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask import jax.numpy as jnp -import common_types -from kernels.ragged_attention import ragged_gqa -from kernels.ragged_attention import ragged_mha -from layers import embeddings -from layers import initializers -from layers import linears -from layers import quantizations +from MaxText import common_types +from MaxText.kernels.ragged_attention import ragged_gqa +from MaxText.kernels.ragged_attention import ragged_mha +from MaxText.layers import embeddings +from MaxText.layers import initializers +from MaxText.layers import linears +from MaxText.layers import quantizations # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py index 38ed8a903..950eda28d 100644 --- a/MaxText/layers/embeddings.py +++ b/MaxText/layers/embeddings.py @@ -20,7 +20,7 @@ import jax from jax import lax import jax.numpy as jnp -from layers import initializers +from MaxText.layers import initializers Config = Any Array = jnp.ndarray diff --git a/MaxText/layers/initializers.py b/MaxText/layers/initializers.py index 5916ecb0c..bf915e757 100644 --- a/MaxText/layers/initializers.py +++ b/MaxText/layers/initializers.py @@ -18,7 +18,7 @@ from flax import linen as nn import jax -import common_types +from MaxText import common_types Array = common_types.Array DType = common_types.DType diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 10d1d0452..43725b466 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -23,18 +23,18 @@ import jax from jax import lax import jax.numpy as jnp -import common_types -from layers import initializers -from layers import normalizations -from layers import quantizations +from MaxText import common_types +from MaxText.layers import initializers +from MaxText.layers import normalizations +from MaxText.layers import quantizations import numpy as np from jax.ad_checkpoint import checkpoint_name from jax.experimental import shard_map import math -import max_logging -import max_utils +from MaxText import max_logging +from MaxText import max_utils from aqt.jax.v2 import aqt_tensor -from kernels import megablox as mblx +from MaxText.kernels import megablox as mblx Array = common_types.Array diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index 604ccc730..27bb9d706 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -23,14 +23,14 @@ import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name # from jax.experimental.pallas.ops.tpu import flash_attention -from layers import attentions -from layers import embeddings -from layers import linears -from layers import normalizations -from layers import models -from layers import quantizations - -import common_types +from MaxText.layers import attentions +from MaxText.layers import embeddings +from MaxText.layers import linears +from MaxText.layers import normalizations +from MaxText.layers import models +from MaxText.layers import quantizations + +from MaxText import common_types from typing import Optional Array = common_types.Array diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 4c2046c1f..98688d9d5 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -24,12 +24,12 @@ import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name -import common_types -from layers import attentions -from layers import embeddings -from layers import linears -from layers import normalizations, quantizations -from layers import pipeline +from MaxText import common_types +from MaxText.layers import attentions +from MaxText.layers import embeddings +from MaxText.layers import linears +from MaxText.layers import normalizations, quantizations +from MaxText.layers import pipeline Array = common_types.Array Config = common_types.Config diff --git a/MaxText/layers/normalizations.py b/MaxText/layers/normalizations.py index 862c586c9..05377e564 100644 --- a/MaxText/layers/normalizations.py +++ b/MaxText/layers/normalizations.py @@ -19,7 +19,7 @@ from flax import linen as nn from jax import lax import jax.numpy as jnp -from layers import initializers +from MaxText.layers import initializers Initializer = initializers.Initializer diff --git a/MaxText/layers/pipeline.py b/MaxText/layers/pipeline.py index 3ce43461b..eb12970d3 100644 --- a/MaxText/layers/pipeline.py +++ b/MaxText/layers/pipeline.py @@ -20,7 +20,7 @@ from jax import numpy as jnp from flax.core import meta from flax import linen as nn -import common_types +from MaxText import common_types import functools from typing import Any diff --git a/MaxText/layers/quantizations.py b/MaxText/layers/quantizations.py index 00cca229f..5ebafd937 100644 --- a/MaxText/layers/quantizations.py +++ b/MaxText/layers/quantizations.py @@ -23,7 +23,7 @@ from aqt.jax.v2.flax import aqt_flax from aqt.jax.v2 import tiled_dot_general from aqt.jax.v2 import calibration -import common_types +from MaxText import common_types from dataclasses import dataclass import flax.linen as nn import jax diff --git a/MaxText/llama_ckpt_conversion_inference_only.py b/MaxText/llama_ckpt_conversion_inference_only.py index c9d74161f..15469f020 100644 --- a/MaxText/llama_ckpt_conversion_inference_only.py +++ b/MaxText/llama_ckpt_conversion_inference_only.py @@ -37,7 +37,7 @@ import jax from flax.training import train_state import max_logging -from train import save_checkpoint +from MaxText.train import save_checkpoint import torch import sys import os diff --git a/MaxText/llama_or_mistral_ckpt.py b/MaxText/llama_or_mistral_ckpt.py index fadcbffa3..cdcd461df 100644 --- a/MaxText/llama_or_mistral_ckpt.py +++ b/MaxText/llama_or_mistral_ckpt.py @@ -49,7 +49,7 @@ from tqdm import tqdm import max_logging -from train import save_checkpoint +from MaxText.train import save_checkpoint import checkpointing diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index b47bf29a4..dc8f04af3 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -19,8 +19,7 @@ import jax import jax.numpy as jnp from jax.experimental import mesh_utils -import checkpointing -import common_types +from MaxText import checkpointing, common_types import functools import time import optax diff --git a/MaxText/maxengine.py b/MaxText/maxengine.py index fca2fdb9e..24645bca0 100644 --- a/MaxText/maxengine.py +++ b/MaxText/maxengine.py @@ -22,13 +22,13 @@ from flax.linen import partitioning as nn_partitioning from flax import struct -from layers import models, quantizations +from MaxText.layers import models, quantizations import jax import jax.numpy as jnp from jax.sharding import PartitionSpec as P -import common_types +from MaxText import common_types from jetstream.core import config_lib from jetstream.engine import engine_api from jetstream.engine import tokenizer_pb2 diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index 14abcd568..e6643ee7f 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -19,7 +19,7 @@ import jax import optax -import max_utils +from MaxText import max_utils from jax.sharding import PartitionSpec as P from jax.experimental.serialize_executable import deserialize_and_load diff --git a/MaxText/standalone_checkpointer.py b/MaxText/standalone_checkpointer.py index 61864efdc..64333ac05 100644 --- a/MaxText/standalone_checkpointer.py +++ b/MaxText/standalone_checkpointer.py @@ -30,11 +30,11 @@ from jax import numpy as jnp import numpy as np -import checkpointing +from MaxText import checkpointing import max_utils import max_logging import pyconfig -from train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint +from MaxText.train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint from layers import models diff --git a/MaxText/standalone_dataloader.py b/MaxText/standalone_dataloader.py index 5e7e3447f..710d4c3f7 100644 --- a/MaxText/standalone_dataloader.py +++ b/MaxText/standalone_dataloader.py @@ -27,7 +27,7 @@ import numpy as np import pyconfig -from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop +from MaxText.train import validate_train_config, get_first_step, load_next_batch, setup_train_loop def data_load_loop(config, state=None): diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index f04ebe190..4c186b5ca 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -19,18 +19,16 @@ import sys import unittest -import common_types - +from MaxText import common_types, max_utils from flax.core import freeze import jax import jax.numpy as jnp -import max_utils -import numpy as np + import pytest import pyconfig -from layers import attentions +from MaxText.layers import attentions Mesh = jax.sharding.Mesh Attention = attentions.Attention diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py index b1f0bed52..ee1c39a7c 100644 --- a/MaxText/tests/gpt3_test.py +++ b/MaxText/tests/gpt3_test.py @@ -18,11 +18,9 @@ import sys import jax import unittest -import max_utils +from MaxText import max_utils from jax.sharding import Mesh -from layers import models -from layers import embeddings -from layers import quantizations +from MaxText.layers import models, embeddings, quantizations import jax.numpy as jnp diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py index e2730fe97..f684ade3a 100644 --- a/MaxText/tests/gradient_accumulation_test.py +++ b/MaxText/tests/gradient_accumulation_test.py @@ -18,7 +18,7 @@ import pytest import string import random -from train import main as train_main +from MaxText.train import main as train_main def generate_random_string(length=10): diff --git a/MaxText/tests/grain_data_processing_test.py b/MaxText/tests/grain_data_processing_test.py index d2034876d..8bd849c40 100644 --- a/MaxText/tests/grain_data_processing_test.py +++ b/MaxText/tests/grain_data_processing_test.py @@ -23,8 +23,8 @@ import unittest import pyconfig -from input_pipeline import _grain_data_processing -from input_pipeline import input_pipeline_interface +from MaxText.input_pipeline import _grain_data_processing +from MaxText.input_pipeline import input_pipeline_interface class GrainDataProcessingTest(unittest.TestCase): diff --git a/MaxText/tests/hf_data_processing_test.py b/MaxText/tests/hf_data_processing_test.py index c0b0002d2..a51f15897 100644 --- a/MaxText/tests/hf_data_processing_test.py +++ b/MaxText/tests/hf_data_processing_test.py @@ -22,8 +22,8 @@ import unittest import pyconfig -from input_pipeline import _hf_data_processing -from input_pipeline import input_pipeline_interface +from MaxText.input_pipeline import _hf_data_processing +from MaxText.input_pipeline import input_pipeline_interface class HfDataProcessingTest(unittest.TestCase): diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index c28de3dcc..d1ae31561 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -18,7 +18,7 @@ import pytest import unittest from absl.testing import absltest -from inference_microbenchmark import main as inference_microbenchmark_main +from MaxText.inference_microbenchmark import main as inference_microbenchmark_main class Inference_Microbenchmark(unittest.TestCase): diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py index 5ec2d1c17..605055823 100644 --- a/MaxText/tests/kernels_test.py +++ b/MaxText/tests/kernels_test.py @@ -21,7 +21,7 @@ import unittest import jax import jax.numpy as jnp -from kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa +from MaxText.kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa class RaggedAttentionTest(unittest.TestCase): diff --git a/MaxText/tests/llama_test.py b/MaxText/tests/llama_test.py index 9c9fa08b4..dca535708 100644 --- a/MaxText/tests/llama_test.py +++ b/MaxText/tests/llama_test.py @@ -19,7 +19,7 @@ import unittest import jax.numpy as jnp from typing import Tuple -from layers.llama2 import embeddings +from MaxText.layers.llama2 import embeddings import numpy as np diff --git a/MaxText/tests/max_utils_test.py b/MaxText/tests/max_utils_test.py index 163f7c591..3e737edd1 100644 --- a/MaxText/tests/max_utils_test.py +++ b/MaxText/tests/max_utils_test.py @@ -16,7 +16,7 @@ """ Tests for the common Max Utils """ import jax -import max_utils +from MaxText import max_utils from flax import linen as nn from flax.training import train_state from jax import numpy as jnp @@ -25,8 +25,8 @@ import optax import pyconfig import unittest -from layers import models -from layers import quantizations +from MaxText.layers import models +from MaxText.layers import quantizations Transformer = models.Transformer diff --git a/MaxText/tests/maxengine_test.py b/MaxText/tests/maxengine_test.py index c59e11b3e..a9a8f3929 100644 --- a/MaxText/tests/maxengine_test.py +++ b/MaxText/tests/maxengine_test.py @@ -21,7 +21,7 @@ import numpy as np import unittest import pyconfig -from maxengine import MaxEngine +from MaxText.maxengine import MaxEngine class MaxEngineTest(unittest.TestCase): diff --git a/MaxText/tests/maxtext_utils_test.py b/MaxText/tests/maxtext_utils_test.py index 9362a5326..03e45d7c0 100644 --- a/MaxText/tests/maxtext_utils_test.py +++ b/MaxText/tests/maxtext_utils_test.py @@ -18,7 +18,7 @@ import unittest import jax.numpy as jnp -import maxtext_utils +from MaxText import maxtext_utils class TestGradientClipping(unittest.TestCase): diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 9791af93b..3f7d76d19 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -15,12 +15,12 @@ import sys import unittest -import common_types +from MaxText import common_types from flax.core import freeze import jax import jax.numpy as jnp -import max_utils +from MaxText import max_utils import numpy as np import pytest diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py index 4c132cf69..3d15b76a2 100644 --- a/MaxText/tests/moe_test.py +++ b/MaxText/tests/moe_test.py @@ -14,14 +14,13 @@ import jax import unittest -from layers import linears -from layers import initializers +from MaxText.layers import linears +from MaxText.layers import initializers import jax.numpy as jnp import pyconfig -import max_utils +from MaxText import max_utils from jax.sharding import Mesh -import flax.linen as nn class TokenDroppingTest(unittest.TestCase): diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index ba289c040..5b46c1271 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -27,7 +27,7 @@ import pytest import pyconfig -import multihost_dataloading +from MaxText import multihost_dataloading class MultihostDataloadingTest(unittest.TestCase): diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 193c677fa..8ce1d8515 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -24,20 +24,20 @@ import pyconfig -from layers import pipeline +from MaxText.layers import pipeline import jax from jax import numpy as jnp from jax.sharding import Mesh -import common_types +from MaxText import common_types import pyconfig -import max_utils +from MaxText import max_utils from flax.core import meta import jax.numpy as jnp from flax import linen as nn from layers import simple_layer -from train import main as train_main +from MaxText.train import main as train_main def assert_same_output_and_grad(f1, f2, *inputs): diff --git a/MaxText/tests/quantizations_test.py b/MaxText/tests/quantizations_test.py index 2218c3449..827b567be 100644 --- a/MaxText/tests/quantizations_test.py +++ b/MaxText/tests/quantizations_test.py @@ -21,7 +21,7 @@ import functools import numpy as np import pyconfig -from layers import quantizations +from MaxText.layers import quantizations import unittest from aqt.jax.v2 import aqt_tensor from aqt.jax.v2 import calibration diff --git a/MaxText/tests/simple_decoder_layer_test.py b/MaxText/tests/simple_decoder_layer_test.py index afa2d0aeb..05f577ad9 100644 --- a/MaxText/tests/simple_decoder_layer_test.py +++ b/MaxText/tests/simple_decoder_layer_test.py @@ -13,7 +13,7 @@ import unittest import pytest -from train import main as train_main +from MaxText.train import main as train_main class SimpleDecoderLayerTest(unittest.TestCase): diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 1bd774946..45a4b6c6b 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -17,8 +17,8 @@ """ Tests for the standalone_checkpointer.py """ import unittest import pytest -from standalone_checkpointer import main as sckpt_main -from standalone_dataloader import main as sdl_main +from MaxText.standalone_checkpointer import main as sckpt_main +from MaxText.standalone_dataloader import main as sdl_main from datetime import datetime import random import string diff --git a/MaxText/tests/tfds_data_processing_test.py b/MaxText/tests/tfds_data_processing_test.py index a89d99dd2..9de2840d3 100644 --- a/MaxText/tests/tfds_data_processing_test.py +++ b/MaxText/tests/tfds_data_processing_test.py @@ -26,8 +26,8 @@ import tensorflow_datasets as tfds import pyconfig -from input_pipeline import _tfds_data_processing -from input_pipeline import input_pipeline_interface +from MaxText.input_pipeline import _tfds_data_processing +from MaxText.input_pipeline import input_pipeline_interface class TfdsDataProcessingTest(unittest.TestCase): diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index c5222f0de..d9fe7a2fa 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -18,8 +18,8 @@ """ import numpy as np -import train_tokenizer -from input_pipeline import _input_pipeline_utils +from MaxText import train_tokenizer +from MaxText.input_pipeline import _input_pipeline_utils import unittest import pytest import tensorflow_datasets as tfds diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 87977079c..f59d80ab3 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -17,8 +17,7 @@ """ Tests for the common Max Utils """ import unittest import pytest -from train_compile import main as train_compile_main -from train import main as train_main +from MaxText.train_compile import main as train_compile_main class TrainCompile(unittest.TestCase): diff --git a/MaxText/tests/train_gpu_smoke_test.py b/MaxText/tests/train_gpu_smoke_test.py index d54831c78..cc92ed9f5 100644 --- a/MaxText/tests/train_gpu_smoke_test.py +++ b/MaxText/tests/train_gpu_smoke_test.py @@ -17,7 +17,7 @@ import os import unittest from absl.testing import absltest -from train import main as train_main +from MaxText.train import main as train_main class Train(unittest.TestCase): diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 3bc4c31e7..7eab6c08c 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -17,7 +17,7 @@ """Smoke test for int8""" import os import unittest -from train import main as train_main +from MaxText.train import main as train_main from absl.testing import absltest diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 74da43509..65d4e3f56 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -17,7 +17,7 @@ """ Smoke test """ import os import unittest -from train import main as train_main +from MaxText.train import main as train_main from absl.testing import absltest diff --git a/MaxText/tests/weight_dtypes_test.py b/MaxText/tests/weight_dtypes_test.py index 579e23310..320f040d7 100644 --- a/MaxText/tests/weight_dtypes_test.py +++ b/MaxText/tests/weight_dtypes_test.py @@ -20,10 +20,9 @@ import pyconfig -import optimizers -from layers import models -from layers import quantizations -import max_utils +from MaxText import optimizers +from MaxText.layers import models, quantizations +from MaxText import max_utils import jax from jax.sharding import Mesh import jax.numpy as jnp diff --git a/MaxText/train_compile.py b/MaxText/train_compile.py index a6dc4682f..276b2ccf1 100644 --- a/MaxText/train_compile.py +++ b/MaxText/train_compile.py @@ -28,7 +28,7 @@ from jax.sharding import Mesh from jax.experimental.serialize_executable import serialize from flax.linen import partitioning as nn_partitioning -import maxtext_utils +from MaxText import maxtext_utils import optimizers import max_utils import pyconfig