When doing pure integer arithmetic, we sometimes need to calculate a * b / c, where a, b and c are signed integers of some fixed size, say 32 bits. Because a * b may exceed the maximum 32-bit value even when the final quotient does not, the common way to prevent overflow is to cast the operands to a wider type, in this case 64-bit integers, and do the calculation there.
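For comparison, the widening approach can be sketched as follows. This assumes 32-bit std::int32_t operands, a nonzero c, and a quotient that fits back into 32 bits; the function name mulDivWide is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical baseline: compute a * b / c by widening to 64 bits.
// Assumes c != 0 and that the final quotient fits back into 32 bits.
std::int32_t mulDivWide(std::int32_t a, std::int32_t b, std::int32_t c)
{
    assert(c != 0);
    // |a * b| <= 2^62, so the 64-bit product is exact.
    std::int64_t product = static_cast<std::int64_t>(a) * b;
    // Division truncates toward zero, as the 32-bit expression would.
    return static_cast<std::int32_t>(product / c);
}
```

For example, mulDivWide(2000000000, 3, 4) returns 1500000000, even though 2000000000 * 3 does not fit in 32 bits.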
But what if we don’t want to cast them? (For whatever reason – maybe the integers are already of the largest built-in type, so you can’t widen them without writing your own big-integer class; or maybe you just want an intellectual challenge.)
The Algorithm
First of all, assume a and b are non-negative and c is positive; we can then further assume a <= b < c. Obviously a and b are interchangeable, so we can have a <= b; and if a >= c or b >= c, say a >= c, then we have:
a = kc + a’, where k >= 1 and 0 <= a’ < c
and so
ab / c = (kc + a’)b / c = kb + a’b / c.
If kb already overflows, then the quotient itself overflows (ab / c is at least kb), so no algorithm can prevent it; just use larger data types. Otherwise, we can calculate kb separately, since it is an ordinary integer multiplication, and reduce the problem to a’b / c where a’ < c. The same applies to b, so we can simplify the problem down to a <= b < c.
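A minimal sketch of this reduction step, assuming non-negative a and b and a positive c, can reduce both operands at once. Writing a = ka*c + ar and b = kb*c + br gives a*b = (ka*b + ar*kb)*c + ar*br, so only ar*br / c is left for the main algorithm. The names reduceOperands and mulFits are hypothetical.

```cpp
#include <cassert>
#include <climits>

// Hypothetical helper: true if x * y fits in an int (x, y >= 0).
bool mulFits(int x, int y)
{
    return y == 0 || x <= INT_MAX / y;
}

// Sketch of the reduction step (assumes a, b >= 0 and c > 0).
// Splits a = ka*c + ar and b = kb*c + br with 0 <= ar, br < c, so that
//   a*b = (ka*b + ar*kb)*c + ar*br.
// On success, whole = ka*b + ar*kb and the caller only needs ar*br / c.
// Returns false if the whole part overflows, in which case the quotient
// itself cannot be represented as an int.
bool reduceOperands(int a, int b, int c, int& whole, int& ar, int& br)
{
    assert(a >= 0 && b >= 0 && c > 0);
    int ka = a / c;
    ar = a % c;
    int kb = b / c;
    br = b % c;
    if (!mulFits(ka, b) || !mulFits(ar, kb))
        return false;
    int p = ka * b;
    int q = ar * kb;
    if (p > INT_MAX - q)
        return false; // the sum of the whole parts overflows
    whole = p + q;
    return true;
}
```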
In the reduced problem the quotient cannot overflow: ab < c^2, and so ab / c < c.
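Written out for non-negative operands, the bound is:

```latex
0 \le a \le b < c
\;\Longrightarrow\;
0 \le ab < c^2
\;\Longrightarrow\;
0 \le \left\lfloor \frac{ab}{c} \right\rfloor < c,
\qquad
0 \le ab \bmod c < c .
```

Both the quotient and the remainder are therefore smaller than c, which is itself representable.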
Now that a < c, we can calculate c / a (if a = 0, the result is trivially 0). Let c = ma + n, where m >= 1 is the quotient and 0 <= n < a is the remainder. Here we have two branches:
(1) If b <= m, then ab <= am <= c.
If b = m and n = 0, we have ab = am = c. Thus the quotient of ab / c is 1 and the remainder is 0.
Otherwise (b < m, or n != 0) we have ab < c, so the quotient is 0 and the remainder is ab.
The algorithm terminates here.
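Branch (1) on its own might look like the sketch below, assuming 0 < a <= b < c. The function name tryBranchOne is hypothetical; it returns false when branch (2) is needed instead.

```cpp
#include <cassert>

// Sketch of branch (1) (hypothetical helper), assuming 0 < a <= b < c.
// Decomposes c = m*a + n; if b <= m, then a*b <= a*m <= c, so the
// quotient and remainder follow directly.
// Returns false when b > m, i.e. when branch (2) applies.
bool tryBranchOne(int a, int b, int c, int& quotient, int& remainder)
{
    assert(0 < a && a <= b && b < c);
    int m = c / a;     // m >= 1 because a < c
    int n = c - m * a; // 0 <= n < a
    if (b > m)
        return false;
    if (b == m && n == 0)
    {
        // a*b == a*m == c exactly.
        quotient = 1;
        remainder = 0;
    }
    else
    {
        // a*b < c, so computing a*b cannot overflow.
        quotient = 0;
        remainder = a * b;
    }
    return true;
}
```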
(2) Otherwise b > m. In this case we calculate b / m. Let b = km + l where k is the quotient and l is the remainder.
Thus, ab = a(km + l) = kma + al = k(ma + n – n) + al = k(ma + n) – kn + al = kc – kn + al.
It then follows that ab/c = k + (al – kn) / c.
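Because this derivation is easy to get wrong by a sign, a brute-force check over small values can confirm the identity ab = kc – kn + al, together with the bounds k <= b and al < c that the next steps rely on. This is an illustrative test, assuming a small limit keeps every product within int range; the function name checkBranchTwo is hypothetical.

```cpp
#include <cassert>

// Illustrative brute-force check (hypothetical helper) of branch (2):
// for every 0 < a <= b < c < limit with b > m, verify
//   a*b == k*c - k*n + a*l,  k <= b  and  a*l < c,
// where c = m*a + n and b = k*m + l. A small limit keeps products in range.
bool checkBranchTwo(int limit)
{
    assert(limit > 0 && limit <= 1000);
    for (int c = 2; c < limit; ++c)
        for (int a = 1; a < c; ++a)
            for (int b = a; b < c; ++b)
            {
                int m = c / a;
                int n = c % a;
                if (b <= m)
                    continue; // branch (1) handles this case
                int k = b / m;
                int l = b % m;
                if (a * b != k * c - k * n + a * l)
                    return false;
                if (k > b || a * l >= c)
                    return false;
            }
    return true;
}
```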
We know that k does not overflow (since k <= b). Set k aside, and focus on calculating the second term.
First, recall that l is the remainder of b / m, so l < m and therefore al < am <= c. Thus the quotient of al / c is 0 and the remainder is al, which does not overflow since al < c.
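Then recall that we also have n < a (since n is the remainder of c / a) and k <= b. Both k and n are still below c, so the recursive call on kn / c satisfies the same assumptions, and its smaller factor is at most n, which is strictly less than a. Since the smaller factor keeps decreasing, the recursion eventually terminates, either in branch (1) or when a factor becomes 0 (in which case the result is trivially 0). Collect the quotient and remainder as u and v, so that kn = uc + v with 0 <= v < c.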
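The termination argument can be illustrated by following the chain of recursive calls iteratively: each step replaces (a, b) by (k, n), and the sketch asserts that the smaller factor strictly decreases. The function recursionDepth is hypothetical and assumes 0 <= a, b < c; it counts how many recursive calls the algorithm would make.

```cpp
#include <algorithm>
#include <cassert>

// Illustrative sketch (hypothetical helper): follow the chain of recursive
// calls on k*n / c and count them, asserting that the smaller factor
// strictly decreases (min(k, n) <= n < a), which is why the chain ends.
// Assumes 0 <= a, b < c.
int recursionDepth(int a, int b, int c)
{
    assert(0 <= a && a < c && 0 <= b && b < c);
    int depth = 0;
    while (a != 0 && b != 0)
    {
        if (a > b)
            std::swap(a, b);
        int m = c / a;
        int n = c % a;
        if (b <= m)
            break; // branch (1): the recursion stops here
        int k = b / m;
        // The next call works on (k, n); its smaller factor is below a.
        assert(std::min(k, n) < a);
        a = k;
        b = n;
        ++depth;
    }
    return depth;
}
```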
With the results above, we know that (al – kn) / c = (al – uc – v) / c = (al – v) / c – u.
If al >= v, then 0 <= al – v < c, so the quotient of ab/c is (k – u) and the remainder is (al – v).
If al < v, then -c < al – v < 0, so the quotient is (k – u – 1) and the remainder is (c + al – v).
This concludes the algorithm.
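The final combining step, taken in isolation, could be sketched as below, assuming k, al = a * l, u and v have already been computed as described. The function name combine is hypothetical.

```cpp
#include <cassert>

// Sketch of the combining step (hypothetical helper). Assumes:
//   k, l from b = k*m + l, with al = a*l and 0 <= al < c
//   u, v from k*n = u*c + v, with 0 <= v < c
// Then a*b / c = k - u + (al - v) / c, normalised so 0 <= remainder < c.
void combine(int k, int al, int u, int v, int c, int& quotient, int& remainder)
{
    assert(0 <= al && al < c && 0 <= v && v < c);
    if (al >= v)
    {
        quotient = k - u;
        remainder = al - v;
    }
    else
    {
        // -c < al - v < 0: borrow one c from the quotient.
        quotient = k - u - 1;
        remainder = c - v + al; // c - v is in [1, c], and the sum stays below c
    }
}
```

For a = 4, b = 7, c = 9 we get m = 2, n = 1, k = 3, l = 1, so al = 4 and kn = 3 = 0 * 9 + 3; combine(3, 4, 0, 3, 9, q, r) then yields q = 3 and r = 1, matching 28 = 3 * 9 + 1.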
Afterword
With this algorithm we can compute a * b / c without overflow, and also the remainder a * b % c (where % is the remainder operator in C).
For the latter, it even works when the quotient a * b / c itself would overflow: the reduction step a = kc + a’ guarantees that ab ≡ a’b (mod c), so the kb term can be dropped and the remainder computed by this algorithm stays correct.
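Since the final remainder (al – v or c + al – v) only involves v and never u, a remainder-only variant can drop the quotient bookkeeping entirely. The following is a minimal sketch, assuming a, b >= 0 and c > 0, with the hypothetical name mulMod; it reduces a and b modulo c first, so it also works when the quotient would overflow.

```cpp
#include <cassert>
#include <utility>

// Remainder-only sketch (hypothetical helper): returns (a * b) % c without
// forming a * b. Assumes a, b >= 0 and c > 0.
int mulMod(int a, int b, int c)
{
    assert(a >= 0 && b >= 0 && c > 0);
    // Reduction step: a*b == (a % c) * (b % c)  (mod c).
    a %= c;
    b %= c;
    if (a == 0 || b == 0)
        return 0;
    if (a > b)
        std::swap(a, b);
    int m = c / a; // c = m*a + n
    int n = c % a;
    if (b <= m)
        return (b == m && n == 0) ? 0 : a * b; // branch (1)
    int k = b / m;          // b = k*m + l
    int l = b % m;
    int al = a * l;         // below c, so it cannot overflow
    int v = mulMod(k, n, c); // (k*n) % c, computed recursively
    // a*b == k*c - k*n + a*l, so a*b == al - v  (mod c).
    return al >= v ? al - v : c - v + al;
}
```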
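Sample code is shown below. It assumes non-negative a and b with a, b < c and c > 0 (apply the reduction step first otherwise). It returns the quotient and remainder in a small struct of two ints, QuotRem (a name chosen here for illustration), rather than in a Vector2, since a vector type that stores floats cannot represent every 32-bit value exactly.

```cpp
#include <cassert>
#include <utility>

// Quotient and remainder of a * b / c.
struct QuotRem
{
    int quotient;
    int remainder;
};

// Computes a * b / c without ever forming a * b.
// Precondition: a, b >= 0, c > 0 and a, b < c.
QuotRem mulDiv(int a, int b, int c)
{
    assert(a >= 0 && b >= 0 && c > 0 && a < c && b < c);
    if (a == 0 || b == 0)
        return {0, 0};

    // Keeping a <= b helps us iterate faster.
    if (a > b)
        std::swap(a, b);

    // Decompose c = m * a + n, and b = k * m + l, so that ab = k * c + (a * l - k * n).
    int m = c / a;
    int n = c - a * m;

    // Early out: if b <= m, then ab <= am <= c.
    if (b <= m)
    {
        if (n == 0 && b == m)
            return {1, 0}; // ab == am == c
        return {0, a * b}; // ab < c, so a * b cannot overflow
    }

    int k = b / m;
    int l = b - k * m;

    // Thus the result = k + [(a * l - k * n) / c].
    int result = k;
    // Since a * m <= c and l < m, we know a * l < c and thus it doesn't overflow.
    int rem = a * l;

    // Recursively calculate k * n / c (k <= b < c and n < a < c, so the precondition holds).
    QuotRem sub = mulDiv(k, n, c);
    result -= sub.quotient;
    if (sub.remainder <= rem)
    {
        rem -= sub.remainder;
    }
    else
    {
        // a * l - v is negative: borrow one c from the quotient.
        rem += c - sub.remainder;
        result -= 1;
    }
    return {result, rem};
}
```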