Skip to content

Commit

Permalink
Fix bug: progress first index starts from 1 instead of 0 (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZX-ModelCloud authored Nov 26, 2024
1 parent c360d1e commit a519eb3
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 33 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def store_input_hook(_, args, kwargs):
quantizers = {}

layer_count = len(layers)
layer_pb = ProgressBar(layer_count)
layer_pb = ProgressBar(range(layer_count))
gpu_memorys = []
cpu_memorys = []
durations = []
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/utils/bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def convert_to_bitblas(model, model_quantlinear, quant_config: QuantizeConfig, s

# Note that due to tvm compilation of per layer modules shapes, the first layer loop is
# relatively much slower if caching is not available. estimate time remaining is highly inaccurate
for name, module in ProgressBar(model.named_modules(), desc=message):
for name, module in ProgressBar(model.named_modules(), desc=message, total=len(list(model.named_modules()))):
if not isinstance(module, model_quantlinear):
continue

Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def pack_model(
max_workers = 1

with ThreadPoolExecutor(max_workers=max_workers) as executor:
with ProgressBar(len(names)) as pbar:
with ProgressBar(total=len(names)) as pbar:
def wrapper(name):
pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar)

Expand Down
118 changes: 88 additions & 30 deletions gptqmodel/utils/progress.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,35 @@
import datetime
import sys

import time
from warnings import warn

class ProgressBarWarning(Warning):
"""base class for all tqdm warnings.
Used for non-external-code-breaking errors, such as garbled printing.
"""
def __init__(self, msg, fp_write=None, *a, **k):
if fp_write is not None:
fp_write("\n" + self.__class__.__name__ + ": " + str(msg).rstrip() + '\n')
else:
super().__init__(msg, *a, **k)

class ProgressBar:
def __init__(self, data, prefix='', length=40, fill='█', desc=""):
self.list = []
if isinstance(data, range):
self.total = len(data)
elif isinstance(data, list):
self.list = data
self.total = len(data)
elif isinstance(data, int):
self.total= data
def __init__(self, iterable=None, total=None, prefix='', bar_length=40, fill='█', desc=""):
if total is None and iterable is not None:
try:
total = len(iterable)
except (TypeError, AttributeError):
total = None
if total == float("inf"):
# Infinite iterations, behave same as unknown
total = None

self.iterable = iterable
self.total = total

self.prefix = prefix
self.length = length
self.bar_length = bar_length
self.fill = fill
self.description = desc
self.current = 0
Expand All @@ -27,39 +41,83 @@ def set_description(self, description):
def progress(self, iteration = None):
if not iteration:
iteration = self.current
percent = ("{0:.1f}").format(100 * (iteration / float(self.total)))
filled_length = int(self.length * iteration // self.total)
bar = self.fill * filled_length + '-' * (self.length - filled_length)
self.log(bar, f"{self.calc_time(iteration)} [{iteration}/{self.total}] {percent}%")
percent = ("{0:.1f}").format(100 * (iteration / float(len(self))))
filled_length = int(self.bar_length * iteration // len(self))
bar = self.fill * filled_length + '-' * (self.bar_length - filled_length)
self.log(bar, f"{self.calc_time(iteration)} [{iteration}/{len(self)}] {percent}%")

def calc_time(self, iteration):
used_time = int(time.time() - self.time)
formatted_time = str(datetime.timedelta(seconds=used_time))
remaining = str(datetime.timedelta(seconds=int((used_time / max(iteration, 1)) * self.total)))
remaining = str(datetime.timedelta(seconds=int((used_time / max(iteration, 1)) * len(self))))
return f"{formatted_time} / {remaining}"

def log(self, bar, log):
print(f'\r{self.prefix} {self.description} |{bar}| {log}', end='', flush=True)

# fix TypeError: 'ProgressBar' object does not support the context manager protocol
def __bool__(self):
if self.total is not None:
return self.total > 0
if self.iterable is None:
raise TypeError('bool() undefined when iterable == total == None')
return bool(self.iterable)

def __len__(self):
return (
self.total if self.iterable is None
else self.iterable.shape[0] if hasattr(self.iterable, "shape")
else len(self.iterable) if hasattr(self.iterable, "__len__")
else self.iterable.__length_hint__() if hasattr(self.iterable, "__length_hint__")
else getattr(self, "total", None))

def __reversed__(self):
try:
orig = self.iterable
except AttributeError:
raise TypeError("'tqdm' object is not reversible")
else:
self.iterable = reversed(self.iterable)
return self.__iter__()
finally:
self.iterable = orig

def __contains__(self, item):
contains = getattr(self.iterable, '__contains__', None)
return contains(item) if contains is not None else item in self.__iter__()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.log(f"{'-' * self.length}", "100.0%")
def __exit__(self, exc_type, exc_value, traceback):
try:
self.close()
except AttributeError:
# maybe eager thread cleanup upon external error
if (exc_type, exc_value, traceback) == (None, None, None):
raise
warn("AttributeError ignored", ProgressBarWarning, stacklevel=2)

def __del__(self):
self.close()

@property
def _comparable(self):
return abs(getattr(self, "pos", 1 << 31))

def __hash__(self):
return id(self)

def __iter__(self):
return self
"""Backward-compatibility to use: for x in tqdm(iterable)"""

def __next__(self):
if self.list:
self.current += 1
return self.list.pop(0)
if self.current < self.total - 1:
self.current += 1
self.progress(self.current)
return self.current
else:
raise StopIteration
# Inlining instance variables as locals (speed optimisation)
iterable = self.iterable

for obj in iterable:
yield obj
return

def close(self):
self.log(f"{'-' * self.bar_length}", "100.0%")


0 comments on commit a519eb3

Please sign in to comment.