-
Notifications
You must be signed in to change notification settings - Fork 0
/
encode_decode.py
90 lines (75 loc) · 2.89 KB
/
encode_decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Encode and decode a single image with SPIHT, then save the decoded
(reconstructed) image to disk.
"""
import math
from argparse import ArgumentParser
from typing import List, Optional
from PIL import Image
import numpy as np
from spiht.spiht_wrapper import SpihtSettings, get_slices_and_h_w
import time
from spiht.utils import imload
from spiht import encode_image,decode_image
def _build_parser() -> ArgumentParser:
    """Return the command-line parser for the SPIHT encode/decode script."""
    cli = ArgumentParser()
    # Positional: path of the image to compress.
    cli.add_argument('image_filename')
    # Rate control: the bit budget is bpp * height * width.
    cli.add_argument('--bpp', type=float, default=0.1, help='bits per pixel')
    cli.add_argument('--quantization_scale', type=float, default=255.0)
    # NOTE(review): the help text says the coarsest level is 4x4, but the
    # default computed in main() keeps it at >= 8 px per side — confirm which
    # is intended before editing the string.
    cli.add_argument(
        '--level',
        type=int,
        default=None,
        help='wavedec2 level. default is set so that the highest DWT level has a width and height of 4.',
    )
    cli.add_argument('--wavelet', type=str, default='bior2.2', help='wavedec2 wavelet')
    cli.add_argument('--mode', type=str, default='reflect', help='wavedec2 mode')
    cli.add_argument('--color_model', type=str, default="IPT")
    # Comma-separated floats; main() splits on "," and converts each to float.
    cli.add_argument('--per_channel_quant_scales', type=str, default="1., 0.2, 0.2")
    cli.add_argument(
        '--out',
        type=str,
        default='reconstructed.png',
        help='save reconstructed image to this file path',
    )
    return cli


parser = _build_parser()
def _default_level(h, w):
    """Return the default number of DWT levels for an ``h`` x ``w`` image.

    This is the largest integer ``level`` for which both ``h / 2**level`` and
    ``w / 2**level`` are at least 8.

    NOTE(review): images smaller than 8 px on a side yield a negative level,
    and a zero dimension makes ``math.log2`` raise ValueError.
    """
    return math.floor(min(math.log2(h / 8), math.log2(w / 8)))


def _to_uint8_image(dec_im, c):
    """Convert a decoded channel-first array into a uint8 array for PIL.

    Single-channel input becomes a 2-D (h, w) array; multi-channel input is
    moved to channel-last (h, w, c). Values are clipped to [0, 1] and scaled
    to 0..255.
    """
    if c == 1:
        dec_im = dec_im[0]
    else:
        dec_im = np.moveaxis(dec_im, 0, -1)
    dec_im = dec_im.clip(0.0, 1.0)
    # Round rather than truncate: a bare astype(np.uint8) drops the fraction,
    # so e.g. 0.999 * 255 = 254.7 would be stored as 254 instead of 255.
    return np.rint(dec_im * 255).astype(np.uint8)


def main(args):
    """Encode ``args.image_filename`` with SPIHT, decode it, and save the result.

    Prints the encode/decode timings, the encoded size, and the mean squared
    error between the input and the raw (unclipped) reconstruction. It then
    writes the reconstruction to ``args.out``.

    Args:
        args: namespace produced by this script's argument parser.
    """
    im = imload(args.image_filename)
    # Channel-first layout (channels, height, width); see the moveaxis below.
    c, h, w = im.shape

    level = _default_level(h, w) if args.level is None else args.level

    # Bit budget for the whole image: bpp counts spatial pixels, not c*h*w.
    max_bits = round(args.bpp * h * w)
    per_channel_quant_scales = [
        float(x) for x in args.per_channel_quant_scales.split(",")
    ]
    spiht_settings = SpihtSettings(
        quantization_scale=args.quantization_scale,
        mode=args.mode,
        wavelet=args.wavelet,
        color_model=args.color_model,
        per_channel_quant_scales=per_channel_quant_scales,
    )

    print(f"Starting encoding of image {c} {h} {w}")
    # perf_counter is monotonic and high-resolution; time.time() can jump.
    st = time.perf_counter()
    encoded = encode_image(im, spiht_settings, level, max_bits)
    et = time.perf_counter()
    print(f"Encoding done in {et-st:.3f}s. Image encoded to {len(encoded.encoded_bytes) / 1024:.2f}KiB")
    print(f" levels: {encoded.level}")
    print(f" max n: {encoded.max_n}")

    # Only the slice layout is needed here; the returned h/w are unused.
    slices, _enc_h, _enc_w = get_slices_and_h_w(h, w, spiht_settings, encoded.level)
    # NOTE(review): assumes slices[0] is the (channel, row, col) slice tuple of
    # the lowest-frequency band — confirm against spiht_wrapper.
    ll_h, ll_w = slices[0][1].stop, slices[0][2].stop
    print(f"ll_h ll_w: {ll_h, ll_w}")

    st = time.perf_counter()
    dec_im = decode_image(encoded, spiht_settings)
    et = time.perf_counter()
    # Mean squared error on the raw decoder output, before clipping/quantizing.
    mse = ((im - dec_im) ** 2).mean()
    print(f"Decoding done in {et-st:.3f}s. MSE: {mse:.5f}")

    Image.fromarray(_to_uint8_image(dec_im, c)).save(args.out)
    print("Saved to ", args.out)
if __name__ == "__main__":
    # Parse the command line and run the encode/decode round trip.
    main(parser.parse_args())