Get all tests to pass locally with no special configuration #1108

Open · wants to merge 1 commit into base: main
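The pattern throughout this diff is the same: implicit top-level imports, which resolve only when MaxText/ itself is on sys.path (e.g. via PYTHONPATH=MaxText or by running scripts from inside that directory), become absolute package-qualified imports. A minimal sketch of the before/after, assuming MaxText/ is an importable package (i.e., it contains an __init__.py) and the repository root is on sys.path:

# Before: resolves only with special configuration such as PYTHONPATH=MaxText
# from train import save_checkpoint

# After: resolves from the repository root with no extra setup,
# assuming MaxText/__init__.py exists so MaxText is a package.
from MaxText.train import save_checkpoint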
MaxText/convert_gemma2_chkpt.py (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@
 import orbax

 import checkpointing
-from train import save_checkpoint
+from MaxText.train import save_checkpoint

 Params = dict[str, Any]

MaxText/convert_gemma_chkpt.py (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@
 import orbax

 import checkpointing
-from train import save_checkpoint
+from MaxText.train import save_checkpoint

 Params = dict[str, Any]

MaxText/convert_gpt3_ckpt_from_paxml.py (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@
 import gc
 import max_logging
 from psutil import Process
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import argparse

MaxText/generate_param_only_checkpoint.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@
 from jax import random
 from typing import Sequence
 from layers import models, quantizations
-from train import save_checkpoint
+from MaxText.train import save_checkpoint

 Transformer = models.Transformer

MaxText/inference_microbenchmark.py (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@

 from jetstream.engine import token_utils

-import max_utils
+from MaxText import max_utils
 import maxengine
 import maxtext_utils
 import profiler
MaxText/input_pipeline/_tfds_data_processing.py (8 changes: 4 additions & 4 deletions)
@@ -24,10 +24,10 @@
 import tensorflow_datasets as tfds
 import jax

-import multihost_dataloading
-import tokenizer
-import sequence_packing
-from input_pipeline import _input_pipeline_utils
+from MaxText import multihost_dataloading
+from MaxText import tokenizer
+from MaxText import sequence_packing
+from MaxText.input_pipeline import _input_pipeline_utils

 AUTOTUNE = tf.data.experimental.AUTOTUNE

MaxText/kernels/ragged_attention.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@
 from jax.experimental.pallas import tpu as pltpu
 import jax.numpy as jnp
 import numpy as np
-import common_types
+from MaxText import common_types

 from jax.experimental import shard_map

MaxText/layers/attentions.py (14 changes: 7 additions & 7 deletions)
@@ -27,13 +27,13 @@
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
 import jax.numpy as jnp
-import common_types
-from kernels.ragged_attention import ragged_gqa
-from kernels.ragged_attention import ragged_mha
-from layers import embeddings
-from layers import initializers
-from layers import linears
-from layers import quantizations
+from MaxText import common_types
+from MaxText.kernels.ragged_attention import ragged_gqa
+from MaxText.kernels.ragged_attention import ragged_mha
+from MaxText.layers import embeddings
+from MaxText.layers import initializers
+from MaxText.layers import linears
+from MaxText.layers import quantizations


 # pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
MaxText/layers/embeddings.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-from layers import initializers
+from MaxText.layers import initializers

 Config = Any
 Array = jnp.ndarray
MaxText/layers/initializers.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@

 from flax import linen as nn
 import jax
-import common_types
+from MaxText import common_types

 Array = common_types.Array
 DType = common_types.DType
MaxText/layers/linears.py (14 changes: 7 additions & 7 deletions)
@@ -23,18 +23,18 @@
 import jax
 from jax import lax
 import jax.numpy as jnp
-import common_types
-from layers import initializers
-from layers import normalizations
-from layers import quantizations
+from MaxText import common_types
+from MaxText.layers import initializers
+from MaxText.layers import normalizations
+from MaxText.layers import quantizations
 import numpy as np
 from jax.ad_checkpoint import checkpoint_name
 from jax.experimental import shard_map
 import math
-import max_logging
-import max_utils
+from MaxText import max_logging
+from MaxText import max_utils
 from aqt.jax.v2 import aqt_tensor
-from kernels import megablox as mblx
+from MaxText.kernels import megablox as mblx


 Array = common_types.Array
MaxText/layers/llama2.py (16 changes: 8 additions & 8 deletions)
@@ -23,14 +23,14 @@
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
 # from jax.experimental.pallas.ops.tpu import flash_attention
-from layers import attentions
-from layers import embeddings
-from layers import linears
-from layers import normalizations
-from layers import models
-from layers import quantizations
-
-import common_types
+from MaxText.layers import attentions
+from MaxText.layers import embeddings
+from MaxText.layers import linears
+from MaxText.layers import normalizations
+from MaxText.layers import models
+from MaxText.layers import quantizations
+
+from MaxText import common_types
 from typing import Optional

 Array = common_types.Array
MaxText/layers/models.py (12 changes: 6 additions & 6 deletions)
@@ -24,12 +24,12 @@
 import jax
 import jax.numpy as jnp
 from jax.ad_checkpoint import checkpoint_name
-import common_types
-from layers import attentions
-from layers import embeddings
-from layers import linears
-from layers import normalizations, quantizations
-from layers import pipeline
+from MaxText import common_types
+from MaxText.layers import attentions
+from MaxText.layers import embeddings
+from MaxText.layers import linears
+from MaxText.layers import normalizations, quantizations
+from MaxText.layers import pipeline

 Array = common_types.Array
 Config = common_types.Config
MaxText/layers/normalizations.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@
 from flax import linen as nn
 from jax import lax
 import jax.numpy as jnp
-from layers import initializers
+from MaxText.layers import initializers

 Initializer = initializers.Initializer

MaxText/layers/pipeline.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@
 from jax import numpy as jnp
 from flax.core import meta
 from flax import linen as nn
-import common_types
+from MaxText import common_types
 import functools
 from typing import Any

MaxText/layers/quantizations.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@
 from aqt.jax.v2.flax import aqt_flax
 from aqt.jax.v2 import tiled_dot_general
 from aqt.jax.v2 import calibration
-import common_types
+from MaxText import common_types
 from dataclasses import dataclass
 import flax.linen as nn
 import jax
MaxText/llama_ckpt_conversion_inference_only.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@
 import jax
 from flax.training import train_state
 import max_logging
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import torch
 import sys
 import os
MaxText/llama_or_mistral_ckpt.py (2 changes: 1 addition & 1 deletion)
@@ -49,7 +49,7 @@
 from tqdm import tqdm

 import max_logging
-from train import save_checkpoint
+from MaxText.train import save_checkpoint
 import checkpointing

MaxText/max_utils.py (3 changes: 1 addition & 2 deletions)
@@ -19,8 +19,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import mesh_utils
-import checkpointing
-import common_types
+from MaxText import checkpointing, common_types
 import functools
 import time
 import optax
MaxText/maxengine.py (4 changes: 2 additions & 2 deletions)
@@ -22,13 +22,13 @@
 from flax.linen import partitioning as nn_partitioning
 from flax import struct

-from layers import models, quantizations
+from MaxText.layers import models, quantizations

 import jax
 import jax.numpy as jnp
 from jax.sharding import PartitionSpec as P

-import common_types
+from MaxText import common_types
 from jetstream.core import config_lib
 from jetstream.engine import engine_api
 from jetstream.engine import tokenizer_pb2
MaxText/maxtext_utils.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@

 import jax
 import optax
-import max_utils
+from MaxText import max_utils
 from jax.sharding import PartitionSpec as P
 from jax.experimental.serialize_executable import deserialize_and_load

MaxText/standalone_checkpointer.py (4 changes: 2 additions & 2 deletions)
@@ -30,11 +30,11 @@
 from jax import numpy as jnp
 import numpy as np

-import checkpointing
+from MaxText import checkpointing
 import max_utils
 import max_logging
 import pyconfig
-from train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint
+from MaxText.train import setup_mesh_and_model, get_first_step, validate_train_config, save_checkpoint

 from layers import models

MaxText/standalone_dataloader.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@
 import numpy as np

 import pyconfig
-from train import validate_train_config, get_first_step, load_next_batch, setup_train_loop
+from MaxText.train import validate_train_config, get_first_step, load_next_batch, setup_train_loop


 def data_load_loop(config, state=None):
MaxText/tests/attention_test.py (8 changes: 3 additions & 5 deletions)
@@ -19,18 +19,16 @@
 import sys
 import unittest

-import common_types
-
+from MaxText import common_types, max_utils
 from flax.core import freeze
 import jax
 import jax.numpy as jnp
-import max_utils
 import numpy as np

 import pytest

 import pyconfig
-
-from layers import attentions
+from MaxText.layers import attentions
+
 Mesh = jax.sharding.Mesh
 Attention = attentions.Attention
MaxText/tests/gpt3_test.py (6 changes: 2 additions & 4 deletions)
@@ -18,11 +18,9 @@
 import sys
 import jax
 import unittest
-import max_utils
+from MaxText import max_utils
 from jax.sharding import Mesh
-from layers import models
-from layers import embeddings
-from layers import quantizations
+from MaxText.layers import models, embeddings, quantizations

 import jax.numpy as jnp

MaxText/tests/gradient_accumulation_test.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@
 import pytest
 import string
 import random
-from train import main as train_main
+from MaxText.train import main as train_main


 def generate_random_string(length=10):
MaxText/tests/grain_data_processing_test.py (4 changes: 2 additions & 2 deletions)
@@ -23,8 +23,8 @@
 import unittest

 import pyconfig
-from input_pipeline import _grain_data_processing
-from input_pipeline import input_pipeline_interface
+from MaxText.input_pipeline import _grain_data_processing
+from MaxText.input_pipeline import input_pipeline_interface


 class GrainDataProcessingTest(unittest.TestCase):
MaxText/tests/hf_data_processing_test.py (4 changes: 2 additions & 2 deletions)
@@ -22,8 +22,8 @@
 import unittest

 import pyconfig
-from input_pipeline import _hf_data_processing
-from input_pipeline import input_pipeline_interface
+from MaxText.input_pipeline import _hf_data_processing
+from MaxText.input_pipeline import input_pipeline_interface


 class HfDataProcessingTest(unittest.TestCase):
MaxText/tests/inference_microbenchmark_smoke_test.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@
 import pytest
 import unittest
 from absl.testing import absltest
-from inference_microbenchmark import main as inference_microbenchmark_main
+from MaxText.inference_microbenchmark import main as inference_microbenchmark_main


 class Inference_Microbenchmark(unittest.TestCase):
MaxText/tests/kernels_test.py (2 changes: 1 addition & 1 deletion)
@@ -21,7 +21,7 @@
 import unittest
 import jax
 import jax.numpy as jnp
-from kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa
+from MaxText.kernels.ragged_attention import ragged_mqa, reference_mqa, ragged_mha, reference_mha, ragged_gqa, reference_gqa


 class RaggedAttentionTest(unittest.TestCase):
MaxText/tests/llama_test.py (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@
 import unittest
 import jax.numpy as jnp
 from typing import Tuple
-from layers.llama2 import embeddings
+from MaxText.layers.llama2 import embeddings
 import numpy as np

MaxText/tests/max_utils_test.py (6 changes: 3 additions & 3 deletions)
@@ -16,7 +16,7 @@

 """ Tests for the common Max Utils """
 import jax
-import max_utils
+from MaxText import max_utils
 from flax import linen as nn
 from flax.training import train_state
 from jax import numpy as jnp
@@ -25,8 +25,8 @@
 import optax
 import pyconfig
 import unittest
-from layers import models
-from layers import quantizations
+from MaxText.layers import models
+from MaxText.layers import quantizations

 Transformer = models.Transformer

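With every module import package-qualified, the test suite should be collectable from the repository root with no PYTHONPATH or working-directory tricks. A hypothetical smoke check (not part of this PR), assuming pytest is installed and the current directory is the repo root:

import subprocess
import sys

# Run the MaxText test suite from the repo root; the package-qualified
# imports above are what let test collection succeed without special setup.
subprocess.run([sys.executable, "-m", "pytest", "MaxText/tests"], check=True)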