Skip to content

Commit

Permalink
Add support for InstructPix2Pix (fixes #40; PR #60)
Browse files Browse the repository at this point in the history
  • Loading branch information
nityanandmathur authored Apr 5, 2024
1 parent 09564e0 commit c30493e
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion daam/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,12 @@ def _unravel_attn(self, x):
with auto_autocast(dtype=torch.float32):
for map_ in x:
map_ = map_.view(map_.size(0), h, w)
map_ = map_[map_.size(0) // 2:] # Filter out unconditional
# For Instruct Pix2Pix, divide the map into three parts: text condition, image condition and unconditional,
# and only keep the text condition part, which is first of the three parts(as per diffusers implementation).
if map_.size(0) == 24:
map_ = map_[:((map_.size(0) // 3)+1)] # Filter out unconditional and image condition
else:
map_ = map_[map_.size(0) // 2:] # # Filter out unconditional
maps.append(map_)

maps = torch.stack(maps, 0) # shape: (tokens, heads, height, width)
Expand Down

0 comments on commit c30493e

Please sign in to comment.