
How do we give input to the model, and where are the processed images stored? #7

Open
harbarex opened this issue May 6, 2019 · 3 comments

Comments

harbarex commented May 6, 2019

I have trained the model, but now I need to test it.
I took demo.py as inspiration for the new demo and am trying to give my own custom caption as input, but I do not know how to do so.

# load the model for the demo
gen = th.nn.DataParallel(pg.Generator(depth=9))
gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device)))

How do I change the above code for making my trained model work?
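A common snag when re-loading such a checkpoint: state dicts saved through `th.nn.DataParallel` prefix every parameter name with `module.`, so loading them into a plain (non-parallel) `Generator` fails with key mismatches. The remapping itself is ordinary dict work; a minimal sketch, with hypothetical key names standing in for the real ones in `GAN_GEN_SHADOW_8.pth`:

```python
def strip_module_prefix(state_dict):
    """Drop a leading 'module.' from every key, if present."""
    return {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }

# Hypothetical keys as they might appear in a DataParallel checkpoint:
saved = {"module.initial_block.conv_1.weight": 0,
         "module.rgb_converters.0.bias": 1}
print(strip_module_prefix(saved))
# {'initial_block.conv_1.weight': 0, 'rgb_converters.0.bias': 1}
```

With torch you would then call `gen.load_state_dict(strip_module_prefix(th.load(path, map_location=str(device))))` on a non-DataParallel generator, or keep the `th.nn.DataParallel` wrapper as in the snippet above and load the dict unchanged.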

adorsho commented Oct 21, 2019

Hello @MeteoRex11, did you solve this problem? I am facing the same problem; please help me.

harbarex (Author) commented
You would need to store the text in torch file format, with the text embeddings already in it. This is a very tedious process, so I did not make a script for it.
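For reference, a minimal sketch of what such a pre-processed pickle might look like. The file name, the `captions` key, and the example words are assumptions; the `vocab`/`rev_vocab` keys match the ones used in the loading snippet later in this thread:

```python
import pickle

# Hypothetical structure mirroring what the pipeline consumes:
# a forward vocabulary, its reverse lookup, and tokenised captions
# stored as index lists.
processed = {
    "vocab": ["<pad>", "a", "man", "with", "beard"],  # assumed example words
    "rev_vocab": {},
    "captions": [[1, 2, 3, 4]],  # "a man with beard" as vocab indices
}
processed["rev_vocab"] = {w: i for i, w in enumerate(processed["vocab"])}

# assumed file name
with open("processed_text.pkl", "wb") as f:
    pickle.dump(processed, f)
```

The text encoder would then be run over these index lists to produce the embeddings stored alongside (or instead of) the raw indices.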

shiv6891 commented Jul 31, 2020

I got it working after a few CPU-specific modifications for my laptop, using the code below.

Referred from: An IPYNB
Note: depending on the PyTorch version, a few minor changes might be required.

After importing your condition augmenter, text encoder, and GAN generator...

import pickle

# Load the pickled vocabulary produced during text pre-processing
file_name = 'path/to/your/.pkl'
with open(file_name, "rb") as pick:
    obj = pickle.loads(pick.read())

max_caption_len = 100
in_str = input('Enter your caption : ')
tokens = in_str.replace('_', ' ').split()
# Map each known word to its vocabulary index; unknown words are dropped
caption = [obj['rev_vocab'][tok] for tok in tokens if tok in obj['rev_vocab']]
# Recover the recognised words, for inspection
recognised = [obj['vocab'][ind] for ind in caption]

# Pad (or truncate) to the fixed caption length
if len(caption) < max_caption_len:
    caption += [obj['rev_vocab']["<pad>"]] * (max_caption_len - len(caption))
else:
    caption = caption[:max_caption_len]

fixed_captions = th.tensor([caption], dtype=th.long)
print("Text initialized!")

fixed_embeddings = text_encoder(fixed_captions)
# Detach from the autograd graph and move to the target device
fixed_embeddings = fixed_embeddings.detach().to(device)
fixed_c_not_hats, mus, _ = condition_augmenter(fixed_embeddings)
fixed_noise = th.zeros(len(fixed_captions), c_pro_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)
fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)
print("Gan input prepared")
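The caption pre-processing above reduces to: tokenise, map words to indices via `rev_vocab` (dropping unknown words), then pad or truncate to a fixed length. A torch-free sketch of just that step, with a made-up toy vocabulary:

```python
def caption_to_indices(text, rev_vocab, max_len, pad="<pad>"):
    """Turn a caption string into a fixed-length list of vocab indices."""
    tokens = text.replace("_", " ").split()
    ids = [rev_vocab[t] for t in tokens if t in rev_vocab]  # drop unknowns
    ids = ids[:max_len]                                     # truncate
    ids += [rev_vocab[pad]] * (max_len - len(ids))          # pad
    return ids

# Toy vocabulary for illustration only:
rev_vocab = {"<pad>": 0, "a": 1, "smiling": 2, "woman": 3}
print(caption_to_indices("a smiling woman", rev_vocab, 6))
# [1, 2, 3, 0, 0, 0]
```

Wrapping the result in `th.tensor([...], dtype=th.long)` gives the `fixed_captions` batch used above.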

And then...

import matplotlib.pyplot as plt
%matplotlib inline
create_grid(
    samples=c_pro_gan.gen(
        fixed_gan_input,
        4,     # depth (resolution stage)
        1.0    # alpha (fade-in factor; 1.0 = fully faded in)
    ),
    scale_factor=1,
    img_file='output.png')

img = plt.imread('output.png')
plt.figure()
plt.imshow(img)
