Skip to content

Commit

Permalink
more features!
Browse files Browse the repository at this point in the history
  • Loading branch information
yreddy31 committed Feb 25, 2021
1 parent 2ace26b commit b4654a6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
except ImportError:
from distutils.core import setup

VERSION = '0.313'
VERSION = '0.314'
setup(
name = 'torch_snippets', # How you named your package folder (MyLib)
packages = ['torch_snippets'], # Chose the same as "name"
Expand Down
31 changes: 25 additions & 6 deletions torch_snippets/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,13 @@ def puttext(ax, string, org, size=15, color=(255,0,0), thickness=2):
path_effects.Normal()])

def dumpdill(obj, fpath, silent=False):
start = time.time()
os.makedirs(parent(fpath), exist_ok=True)
with open(fpath, 'wb') as f:
dill.dump(obj, f)
if not silent: logger.info('Dumped object @ {}'.format(fpath))
if not silent:
fsize = os.path.getsize(fpath) >> 20
logger.info(f'Dumped object of size `~{fsize} MB` @ "{fpath}" in {time.time()-start:.2f} seconds')

def loaddill(fpath):
with open(fpath, 'rb') as f:
Expand Down Expand Up @@ -653,15 +656,31 @@ def to_relative(input, shape):
bbs = bbfy(input)
return [bb.relative((h,w)) for bb in bbs]

def compute_eps(eps):
if isinstance(eps, tuple):
if len(eps) == 4:
epsx, epsy, epsX, epsY = eps
else:
epsx, epsy = eps
epsx, epsy, epsX, epsY = epsx/2, epsy/2, epsx/2, epsy/2
else:
epsx, epsy, epsX, epsY = eps/2, eps/2, eps/2, eps/2
return epsx, epsy, epsX, epsY

def enlarge_bbs(bbs, eps=0.2):
"enlarge all `bbs` by `eps` fraction (or eps*100 percent)"
epsx, epsy = eps if isinstance(eps, tuple) else (eps, eps)
"enlarge all `bbs` by `eps` fraction (i.e., eps*100 percent)"
bbs = bbfy(bbs)
epsx, epsy, epsX, epsY = compute_eps(eps)
bbs = bbfy(bbs)
shs = [(bb.h,bb.w) for bb in bbs]
return [BB(x-(w*eps/2), y-(h*eps/2), X+(w*eps/2), Y+(h*eps/2))\
return [BB(x-(w*epsx), y-(h*epsy), X+(w*epsX), Y+(h*epsY))\
for (x,y,X,Y),(h,w) in zip(bbs, shs)]

def shrink_bbs(bbs, eps=0.2):
"shrink all `bbs` by `eps` fraction (or eps*100 percent)"
"shrink all `bbs` by `eps` fraction (i.e., eps*100 percent)"
bbs = bbfy(bbs)
epsx, epsy, epsX, epsY = compute_eps(eps)
bbs = bbfy(bbs)
shs = [(bb.h,bb.w) for bb in bbs]
return [BB(x+(w*eps/2), y+(h*eps/2), X-(w*eps/2), Y-(h*eps/2))\
return [BB(x+(w*epsx), y+(h*epsy), X-(w*epsX), Y-(h*epsY))\
for (x,y,X,Y),(h,w) in zip(bbs, shs)]
3 changes: 2 additions & 1 deletion torch_snippets/torch_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def report_metrics(self, pos, **report):
print(f'\r{log}{current_iteration}{info(report, self.precision)}{elapsed}', end=end)

try:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import ProgressBarBase
class LightningReport(ProgressBarBase):
def __init__(self, epochs, print_every=None, print_total=None, precision=4, old_report=None):
Expand Down Expand Up @@ -226,7 +227,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
def __getattr__(self, attr, **kwargs):
return getattr(self.report, attr, **kwargs)

__all__ += ['LightningReport']
__all__ += ['LightningReport', 'pl']
except:
logger.warning('Not importing Lightning Report')

Expand Down

0 comments on commit b4654a6

Please sign in to comment.