Skip to content

Commit

Permalink
fix: seed should be 64 bit
Browse files Browse the repository at this point in the history
  • Loading branch information
leejet committed Sep 3, 2023
1 parent e5a7aec commit 4584286
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct Option {
int sample_steps = 20;
float strength = 0.75f;
RNGType rng_type = STD_DEFAULT_RNG;
int seed = 42;
int64_t seed = 42;
bool verbose = false;

void print() {
Expand All @@ -106,7 +106,7 @@ struct Option {
printf(" sample_steps: %d\n", sample_steps);
printf(" strength: %.2f\n", strength);
printf(" rng: %s\n", rng_type_to_str[rng_type]);
printf(" seed: %d\n", seed);
printf(" seed: %ld\n", seed);
}
};

Expand Down Expand Up @@ -233,7 +233,7 @@ void parse_args(int argc, const char* argv[], Option* opt) {
invalid_arg = true;
break;
}
opt->seed = std::stoi(argv[i]);
opt->seed = std::stoll(argv[i]);
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv);
exit(0);
Expand Down
4 changes: 2 additions & 2 deletions rng.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class RNG {
public:
virtual void manual_seed(uint32_t seed) = 0;
virtual void manual_seed(uint64_t seed) = 0;
virtual std::vector<float> randn(uint32_t n) = 0;
};

Expand All @@ -15,7 +15,7 @@ class STDDefaultRNG : public RNG {
std::default_random_engine generator;

public:
void manual_seed(uint32_t seed) {
void manual_seed(uint64_t seed) {
generator.seed(seed);
}

Expand Down
2 changes: 1 addition & 1 deletion rng_philox.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class PhiloxRNG : public RNG {
this->offset = 0;
}

void manual_seed(uint32_t seed) {
void manual_seed(uint64_t seed) {
this->seed = seed;
this->offset = 0;
}
Expand Down
4 changes: 2 additions & 2 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3823,7 +3823,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
int height,
SampleMethod sample_method,
int sample_steps,
int seed) {
int64_t seed) {
std::vector<uint8_t> result;
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10M
Expand Down Expand Up @@ -3911,7 +3911,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
SampleMethod sample_method,
int sample_steps,
float strength,
int seed) {
int64_t seed) {
std::vector<uint8_t> result;
if (init_img_vec.size() != width * height * 3) {
return result;
Expand Down
4 changes: 2 additions & 2 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class StableDiffusion {
int height,
SampleMethod sample_method,
int sample_steps,
int seed);
int64_t seed);
std::vector<uint8_t> img2img(
const std::vector<uint8_t>& init_img,
const std::string& prompt,
Expand All @@ -51,7 +51,7 @@ class StableDiffusion {
SampleMethod sample_method,
int sample_steps,
float strength,
int seed);
int64_t seed);
};

void set_sd_log_level(SDLogLevel level);
Expand Down

0 comments on commit 4584286

Please sign in to comment.