Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rename temp->temperature #280

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion audiocraft/models/audiogen.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
self.duration = duration
self.generation_params = {
'use_sampling': use_sampling,
'temp': temperature,
'temperature': temperature,
'top_k': top_k,
'top_p': top_p,
'cfg_coef': cfg_coef,
Expand Down
16 changes: 8 additions & 8 deletions audiocraft/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def _sample_next_token(self,
cfg_conditions: CFGConditions,
unconditional_state: State,
use_sampling: bool = False,
temp: float = 1.0,
temperature: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
Expand All @@ -325,7 +325,7 @@ def _sample_next_token(self,
condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
should be twice the batch size, being the concatenation of the conditions + null conditions.
use_sampling (bool): Whether to use a sampling strategy or not.
temp (float): Sampling temperature.
temperature (float): Sampling temperature.
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
cfg_coef (float, optional): classifier free guidance coefficient
Expand Down Expand Up @@ -363,9 +363,9 @@ def _sample_next_token(self,
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
logits = logits[..., -1] # [B x K x card]

# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
if use_sampling and temp > 0.0:
probs = torch.softmax(logits / temp, dim=-1)
# Apply softmax for sampling if temperature > 0. Else, do greedy sampling to avoid zero division error.
if use_sampling and temperature > 0.0:
probs = torch.softmax(logits / temperature, dim=-1)
if top_p > 0.0:
next_token = utils.sample_top_p(probs, p=top_p)
elif top_k > 0:
Expand All @@ -384,7 +384,7 @@ def generate(self,
num_samples: tp.Optional[int] = None,
max_gen_len: int = 256,
use_sampling: bool = True,
temp: float = 1.0,
temperature: float = 1.0,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
Expand All @@ -401,7 +401,7 @@ def generate(self,
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
max_gen_len (int): Maximum generation length.
use_sampling (bool): Whether to use a sampling strategy or not.
temp (float): Sampling temperature.
temperature (float): Sampling temperature.
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
cfg_coeff (float, optional): Classifier-free guidance coefficient.
Expand Down Expand Up @@ -492,7 +492,7 @@ def generate(self,
assert not (curr_sequence == unknown_token).any()
# sample next token from the model, next token shape is [B, K, 1]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temperature, top_k, top_p,
cfg_coef=cfg_coef)
# ensure the tokens that should be masked are properly set to special_token_id
# as the model never output special_token_id
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/models/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
self.duration = duration
self.generation_params = {
'use_sampling': use_sampling,
'temp': temperature,
'temperature': temperature,
'top_k': top_k,
'top_p': top_p,
'cfg_coef': cfg_coef,
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/solvers/musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, cfg: omegaconf.DictConfig):
# easier access to sampling parameters
self.generation_params = {
'use_sampling': self.cfg.generate.lm.use_sampling,
'temp': self.cfg.generate.lm.temp,
'temperature': self.cfg.generate.lm.temperature,
'top_k': self.cfg.generate.lm.top_k,
'top_p': self.cfg.generate.lm.top_p,
}
Expand Down
2 changes: 1 addition & 1 deletion config/solver/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ generate:
sample_rate: null
lm:
use_sampling: false
temp: 1.0
temperature: 1.0
top_k: 0
top_p: 0.0
evaluate:
Expand Down
2 changes: 1 addition & 1 deletion config/solver/musicgen/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ generate:
remove_prompts: false
# generation params
use_sampling: false
temp: 1.0
temperature: 1.0
top_k: 0
top_p: 0.0
evaluate:
Expand Down
Loading