-
Notifications
You must be signed in to change notification settings - Fork 0
/
multiplication.cpp
355 lines (282 loc) · 9.04 KB
/
multiplication.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
#define _CRT_RAND_S // rand_s (must precede any header that pulls in <stdlib.h>)
#include <stdlib.h> // rand_s
#include <climits> // UINT_MAX
#include <cstddef> // size_t
#include <cstdlib>
#include <cstring> // memcpy
#include <ctime> // clock
#include <iostream> // cout
#include <random> // mt19937, uniform_int_distribution
#include <utility> // std::move
using namespace std;
// Numeric configuration.
// A "digit" is one position of a long number written in base BASE; every stored digit is in [0, BASE - 1].
using digit = unsigned int;              // type of one digit; must accommodate BASE - 1
using double_digit = unsigned long long; // type for intermediate values up to BASE * BASE - 1
constexpr digit BASE = UINT_MAX;         // base of the numerical notation (UINT_MAX == 4,294,967,295 for a 32-bit unsigned int)
constexpr size_t MIN_LENGTH_FOR_KARATSUBA = 4; // if a factor is shorter, it is multiplied with the naive method
constexpr size_t N_TESTS = 3;            // number of timed repetitions per method in main()
constexpr size_t NUMBERS_LENGTH = 10000; // number of digits of each benchmark operand
// (BASE-1) + (BASE-1)*(BASE-1) + (BASE-1) == BASE*BASE - 1 must fit in double_digit.
static_assert(sizeof(double_digit) >= 2 * sizeof(digit),
              "double_digit must be able to hold BASE * BASE - 1");

/* Struct to store long numbers in the chosen numerical notation.
 * Owns a heap buffer of digits: copies are deep, moves transfer the buffer. */
struct LongNumber {
    digit* val; // digits of the number; val[0] is the right-most (low-order) digit
    size_t len; // number of digits in use (the buffer may be longer after normalizeDigits())

    // Empty number: no digits and no buffer.
    LongNumber() : val(nullptr), len(0) {}
    // Number with `length` digits, all zero (value-initialized buffer).
    LongNumber(size_t length) : val(new digit[length]()), len(length) {}
    // Copy-constructor: deep copy of obj's digits.
    LongNumber(const LongNumber& obj) : val(new digit[obj.len]), len(obj.len) {
        copyDigits(val, obj.val, len);
    }
    // Move-constructor: takes over obj's buffer and leaves obj empty.
    LongNumber(LongNumber&& obj) noexcept : val(obj.val), len(obj.len) {
        obj.val = nullptr;
        obj.len = 0;
    }
    // Copy-assignment: self-assignment safe; the new buffer is filled before the old one is released.
    LongNumber& operator = (const LongNumber& obj) {
        if (this != &obj) {
            digit* fresh = new digit[obj.len];
            copyDigits(fresh, obj.val, obj.len);
            delete[] val;
            val = fresh;
            len = obj.len;
        }
        return *this;
    }
    // Move-assignment: releases our buffer, takes over obj's and leaves obj empty.
    LongNumber& operator = (LongNumber&& obj) noexcept {
        if (this != &obj) {
            delete[] val;
            val = obj.val;
            len = obj.len;
            obj.val = nullptr;
            obj.len = 0;
        }
        return *this;
    }
    ~LongNumber() {
        delete[] val; // deleting nullptr is a no-op
    }
    // These declarations also let operator* find the multiplication functions (defined below) via ADL.
    friend LongNumber naive_multiply(LongNumber x, LongNumber y);
    friend LongNumber karatsuba_multiply(LongNumber x, LongNumber y);
    friend void print(LongNumber number);

    // Copies `count` digits from src to dst (no-op for count == 0, so null buffers are fine).
    static void copyDigits(digit* dst, const digit* src, size_t count) {
        if (count != 0) std::memcpy(dst, src, count * sizeof(digit));
    }
    // Returns digit i, or 0 for positions at or beyond len (implicit heading zeros).
    digit digitAt(size_t i) const {
        return i < len ? val[i] : 0;
    }
    // Increase the number of digits in the long number by adding heading zeros.
    void increaseDigits(size_t new_length) {
        if (new_length <= len) return;
        digit* temp = new digit[new_length](); // zero-filled
        copyDigits(temp, val, len);
        delete[] val;
        val = temp;
        len = new_length;
    }
    // Normalize the number by dropping heading zeros (keeps at least one digit if len >= 1).
    // Only len shrinks; the buffer itself is not reallocated.
    void normalizeDigits() {
        while (len > 1 && val[len - 1] == 0) --len;
    }
    // Multiplies the number by BASE^n:
    // 1. Increases the length by n
    // 2. "Shifts" value by n digits
    // 3. Fills n empty digits (val[0...n-1]) with zeros
    // For example, if BASE == 10 && x.val == {3,2,1}, then x.shift(3) makes x.val == {0,0,0,3,2,1}
    void shift(size_t n) {
        digit* temp = new digit[len + n](); // zero-filled, so val[0...n-1] are already 0
        copyDigits(temp + n, val, len);
        delete[] val;
        val = temp;
        len += n;
    }
    // Operator + to sum two long numbers. Returns (this + A); neither operand is modified.
    // Result length is max(len, A.len), plus one digit if the final carry is non-zero.
    LongNumber operator + (const LongNumber& A) const {
        const size_t n = len > A.len ? len : A.len;
        LongNumber R(n); // to store the result
        double_digit carryover = 0;
        for (size_t i = 0; i < n; ++i) {
            // Widen before adding: two digits plus a carry can exceed the range of `digit`.
            const double_digit sum = static_cast<double_digit>(digitAt(i)) + A.digitAt(i) + carryover;
            R.val[i] = static_cast<digit>(sum % BASE);
            carryover = sum / BASE;
        }
        if (carryover != 0) {
            R.increaseDigits(n + 1);
            R.val[n] = static_cast<digit>(carryover);
        }
        return R;
    }
    // Operator - to subtract two long numbers: A from this. Returns (this - A) with max(len, A.len) digits;
    // neither operand is modified.
    // Precondition: this >= A. Otherwise the final borrow is dropped and the result wraps modulo BASE^n.
    LongNumber operator - (const LongNumber& A) const {
        const size_t n = len > A.len ? len : A.len;
        LongNumber R(n); // to store the result
        digit borrow = 0;
        for (size_t i = 0; i < n; ++i) {
            const double_digit minuend = digitAt(i);
            const double_digit subtrahend = static_cast<double_digit>(A.digitAt(i)) + borrow;
            if (minuend >= subtrahend) {
                R.val[i] = static_cast<digit>(minuend - subtrahend);
                borrow = 0;
            }
            else {
                // minuend < subtrahend <= BASE, so the borrowed value lands in [0, BASE - 1].
                R.val[i] = static_cast<digit>(minuend + BASE - subtrahend);
                borrow = 1;
            }
        }
        return R;
    }
    // Operator * to multiply two long numbers: this and A. Neither operand is modified.
    // Short factors use naive_multiply, longer ones karatsuba_multiply; the result is normalized.
    LongNumber operator * (const LongNumber& A) const {
        LongNumber lhs(*this);
        LongNumber rhs(A);
        lhs.normalizeDigits();
        rhs.normalizeDigits();
        const bool lhs_zero = lhs.len == 0 || (lhs.len == 1 && lhs.val[0] == 0);
        const bool rhs_zero = rhs.len == 0 || (rhs.len == 1 && rhs.val[0] == 0);
        if (lhs_zero || rhs_zero) return LongNumber(1);
        if (lhs.len == 1 && lhs.val[0] == 1) return rhs;
        if (rhs.len == 1 && rhs.val[0] == 1) return lhs;
        if (lhs.len < MIN_LENGTH_FOR_KARATSUBA || rhs.len < MIN_LENGTH_FOR_KARATSUBA) {
            return naive_multiply(std::move(lhs), std::move(rhs));
        }
        return karatsuba_multiply(std::move(lhs), std::move(rhs));
    }
};

/* Naive (schoolbook) multiplication of two long numbers, O(x.len * y.len).
 * Each row of partial products is accumulated directly into the result.
 * Returns a normalized number; an operand with no digits is treated as 0. */
LongNumber naive_multiply(LongNumber x, LongNumber y) {
    if (x.len == 0 || y.len == 0) return LongNumber(1);
    LongNumber result(x.len + y.len); // zero-filled; the product never needs more digits
    for (size_t i = 0; i < y.len; ++i) {
        const double_digit multiplier = y.val[i];
        double_digit carryover = 0;
        for (size_t j = 0; j < x.len; ++j) {
            // At most (BASE-1) + (BASE-1)*(BASE-1) + (BASE-1) == BASE*BASE - 1, which fits in double_digit.
            const double_digit cell = static_cast<double_digit>(result.val[i + j]) + multiplier * x.val[j] + carryover;
            result.val[i + j] = static_cast<digit>(cell % BASE);
            carryover = cell / BASE;
        }
        // Earlier rows only wrote positions up to i + x.len - 1, so this slot is still zero.
        result.val[i + x.len] = static_cast<digit>(carryover);
    }
    result.normalizeDigits();
    return result;
}

// Returns a new number made of `count` digits of `source` starting at index `from`
// (positions past source.len read as zero).
static LongNumber sliceDigits(const LongNumber& source, size_t from, size_t count) {
    LongNumber part(count);
    for (size_t i = 0; i < count; ++i) part.val[i] = source.digitAt(from + i);
    return part;
}

/* Multiplication of two long numbers using Karatsuba algorithm.
 * With n the common length rounded up to even and m = n / 2:
 *   a = a1*BASE^m + a0,  b = b1*BASE^m + b0
 *   a*b = a1b1*BASE^2m + ((a0+a1)(b0+b1) - a0b0 - a1b1)*BASE^m + a0b0
 * The three sub-products go through operator*, which recurses or falls back to naive_multiply.
 * Returns a normalized number; an operand with no digits is treated as 0. */
LongNumber karatsuba_multiply(LongNumber a, LongNumber b) {
    if (a.len == 0 || b.len == 0) return LongNumber(1);
    size_t n = a.len > b.len ? a.len : b.len;
    if (n % 2 == 1) ++n; // make the length even so both halves have m digits
    const size_t m = n / 2;
    const LongNumber a0 = sliceDigits(a, 0, m);
    const LongNumber a1 = sliceDigits(a, m, m);
    const LongNumber b0 = sliceDigits(b, 0, m);
    const LongNumber b1 = sliceDigits(b, m, m);
    LongNumber low = a0 * b0;
    LongNumber high = a1 * b1;
    // Equals a0*b1 + a1*b0, so every intermediate difference is non-negative.
    LongNumber middle = (a0 + a1) * (b0 + b1) - low - high;
    middle.shift(m);   // = middle * BASE^m
    high.shift(2 * m); // = high * BASE^2m
    LongNumber result = low + middle + high;
    result.normalizeDigits();
    return result;
}
// prints LongNumber
void print(LongNumber number) {
digit temp;
bool metSignificant = false;
// go through all digits of the number
for (int i = number.len - 1; i >= 0; --i) {
if (!metSignificant) { // if it's a heading digit of the number
if (i == 0 || number.val[i] != 0) {
cout << number.val[i];
metSignificant = true;
}
}
else { // if it's not a heading digit of the number
cout << ",";
// print heading zeros in the current digit, if any
if (number.val[i] == 0) {
temp = (BASE - 1) / 10;
}
else {
temp = (BASE - 1) / number.val[i] / 10;
}
while (temp > 0) {
cout << "0";
temp /= 10;
}
cout << number.val[i];
}
}
cout << " (length " << number.len << ")\n";
}
// Returns a LongNumber of `length` digits, each drawn uniformly from [0, BASE - 1].
// Uses a std::mt19937 engine seeded once (on first call) from std::random_device. Unlike the
// previous rand_s call, whose error code was ignored and could leave the digit uninitialized,
// this is portable and yields an exactly uniform range with no modulo bias.
LongNumber generateRandom(size_t length) {
    static std::mt19937 engine{std::random_device{}()};
    std::uniform_int_distribution<digit> pick_digit(0, BASE - 1);
    LongNumber result(length);
    for (size_t i = 0; i < length; ++i) {
        result.val[i] = pick_digit(engine);
    }
    return result;
}
int main(int argc, char** argv) {
LongNumber numberA(NUMBERS_LENGTH), numberB(NUMBERS_LENGTH);
numberA = generateRandom(numberA.len);
numberB = generateRandom(numberB.len);
cout << "BASE = " << BASE << "\n";
cout << "N_TESTS = " << N_TESTS << "\n";
cout << "Length A = " << numberA.len << "; length B = " << numberB.len << "\n";
//print(numberA);
//cout << " * \n";
//print(numberB);
//cout << " = \n";
LongNumber temp;
cout << "Method 1. Naive multiplication.\n";
clock_t start = clock();
for (size_t i = 0; i < N_TESTS; ++i)
naive_multiply(numberA, numberB);
double naive_elapsed = (double)(clock() - start) / CLOCKS_PER_SEC;
cout << "Elapsed time: " << naive_elapsed << "\n";
cout << "Method 2. Karatsuba multiplication.\n";
start = clock();
for (size_t i = 0; i < N_TESTS; ++i)
karatsuba_multiply(numberA, numberB);
double karats_elapsed = (double)(clock() - start) / CLOCKS_PER_SEC;
cout << "Elapsed time: " << karats_elapsed << "\n";
if (naive_elapsed > karats_elapsed) {
cout << "Karatsuba algorithm is " << (float)naive_elapsed / karats_elapsed << " times faster\n";
}
else {
cout << "Naive algorithm is " << (float)karats_elapsed / naive_elapsed << " times faster\n";
}
return 0;
}