本文简单介绍了一种大整数乘法的实现方式
当整数范围较大时,直接使用乘法运算符(*)很容易导致数值溢出,如果开发工作中确实需要处理这种大范围的整数,那么我们便需要实现一下大(范围)整数的乘法运算(一般方法便是将大整数表达为字符串,然后基于字符串来进行乘法运算).
在实现大整数乘法之前,我们先来实现一下大整数的加法运算,朴素方法便是从低到高按位进行加法操作,并考虑进位的影响,代码大概如下(Lua):
local big_int = {}local function digit_num(v)return #tostring(v)
endfunction big_int.add(a, b)a = tostring(a)b = tostring(b)local result_buffer = {}local cur_a_index = -1local cur_b_index = -1local last_carry = 0while true dolocal sub_a_str = a:sub(cur_a_index, cur_a_index)local sub_b_str = b:sub(cur_b_index, cur_b_index)local sub_a = tonumber(sub_a_str)local sub_b = tonumber(sub_b_str)if last_carry == 0 thenif not sub_a and not sub_b thenbreakelseif not sub_a thentable.insert(result_buffer, 1, b:sub(1, cur_b_index))breakelseif not sub_b thentable.insert(result_buffer, 1, a:sub(1, cur_a_index))breakendendsub_a = sub_a or 0sub_b = sub_b or 0local sub_result = sub_a + sub_b + last_carryif sub_result >= 10 thenlast_carry = 1table.insert(result_buffer, 1, tostring(sub_result):sub(2))elselast_carry = 0table.insert(result_buffer, 1, tostring(sub_result))endcur_a_index = cur_a_index - 1cur_b_index = cur_b_index - 1endreturn table.concat(result_buffer)
endreturn big_int
上述代码的基本思路便是按位做加法,但实际上我们可以将"位"的概念扩展一下,延伸为"段",之前我们总是按"位"做加法,现在我们可以按"段"来做加法,这样做的好处便是加法效率会有比较大的提升(注意,这个提升也只是单方面的,并不意味着代码整体效率一定会提升),当然也会带来一些"副作用",譬如前导零问题:
考虑数字 1100 和 2201 相加,我们按两位一段来进行相加,即 11 和 22 相加, 00 和 01 相加,其中 00 和 01 相加的结果为 1(00 和 01 转为数字之后的相加结果),直接转换回字符串(结果为字符串"1",但实际上应为字符串"01")会导致前导零丢失.
处理了相关问题的代码如下(Lua):
local big_int = {}local min_add_digit_num = 9local function digit_num(v)return #tostring(v)
endfunction big_int.add(a, b)a = tostring(a)b = tostring(b)local result_buffer = {}local a_start_index = -min_add_digit_numlocal a_end_index = -1local b_start_index = -min_add_digit_numlocal b_end_index = -1local last_carry = 0while true dolocal sub_a_str = a:sub(a_start_index, a_end_index)local sub_b_str = b:sub(b_start_index, b_end_index)local sub_a = tonumber(sub_a_str)local sub_b = tonumber(sub_b_str)if last_carry == 0 thenif not sub_a and not sub_b thenbreakelseif not sub_a thentable.insert(result_buffer, 1, b:sub(1, b_end_index))breakelseif not sub_b thentable.insert(result_buffer, 1, a:sub(1, a_end_index))breakendendsub_a = sub_a or 0sub_b = sub_b or 0local sub_result = sub_a + sub_b + last_carrylocal sub_result_digit_num = digit_num(sub_result)local sub_max_digit_num = math.max(digit_num(sub_a_str), digit_num(sub_b_str))if sub_max_digit_num >= min_add_digit_num thenif sub_result_digit_num > sub_max_digit_num thenlast_carry = 1table.insert(result_buffer, 1, tostring(sub_result):sub(2))elseif sub_result_digit_num == sub_max_digit_num thenlast_carry = 0table.insert(result_buffer, 1, tostring(sub_result))else-- handling heading zeroslast_carry = 0table.insert(result_buffer, 1, string.rep("0", sub_max_digit_num - sub_result_digit_num) .. tostring(sub_result))endelselast_carry = 0table.insert(result_buffer, 1, tostring(sub_result))enda_end_index = a_start_index - 1a_start_index = a_start_index - min_add_digit_numb_end_index = b_start_index - 1b_start_index = b_start_index - min_add_digit_numendreturn table.concat(result_buffer)
endreturn big_int
简单测试一下,新版本代码的效率大概是老版本的 3 倍左右.
OK,实现了大整数加法,我们接着来实现大整数乘法,实际上来讲,大整数乘法也是可以按位进行乘法然后直接运用大整数加法来解决的,但是这种实现方式效率较差,更好的方法还是运用二分求解:
考虑大整数乘法 a ∗ b a * b a∗b,我们将 a a a 分为高位 a h a_h ah 和低位 a l a_l al,将 b b b 分为高位 b h b_h bh 和低位 b l b_l bl,并设 a l a_l al 的位数为 n n n, b l b_l bl 的位数为 m m m, 则有:
a ∗ b = ( a h ∗ 1 0 n + a l ) ∗ ( b h ∗ 1 0 m + b l ) = a h ∗ b h ∗ 1 0 n + m + a h ∗ b l ∗ 1 0 n + a l ∗ b h ∗ 1 0 m + a l ∗ b l a * b = (a_h * 10^{n} + a_l) * (b_h * 10^{m} + b_l) \\ = a_h * b_h * 10^{n + m} + a_h * b_l * 10^{n} + a_l * b_h * 10^{m} + a_l * b_l a∗b=(ah∗10n+al)∗(bh∗10m+bl)=ah∗bh∗10n+m+ah∗bl∗10n+al∗bh∗10m+al∗bl
其中 a h ∗ b h , a h ∗ b l , a l ∗ b h , a l ∗ b l a_h * b_h, a_h * b_l, a_l * b_h, a_l * b_l ah∗bh,ah∗bl,al∗bh,al∗bl 都是相同的大整数乘法子问题,我们直接递归求解即可,代码大概如下(Lua,重复代码已省略):
local min_mul_digit_num = 5function big_int.mul(a, b)a = tostring(a)b = tostring(b)local a_digit_num = digit_num(a)local b_digit_num = digit_num(b)if a_digit_num + b_digit_num <= 2 * min_mul_digit_num thenreturn tostring((tonumber(a) or 0) * (tonumber(b) or 0))elselocal a_digit_num_h = math.ceil(a_digit_num / 2)local a_digit_num_l = a_digit_num - a_digit_num_hlocal b_digit_num_h = math.ceil(b_digit_num / 2)local b_digit_num_l = b_digit_num - b_digit_num_hlocal ah = a:sub(1, a_digit_num_h)local al = a:sub(a_digit_num_h + 1, -1)local bh = b:sub(1, b_digit_num_h)local bl = b:sub(b_digit_num_h + 1, -1)local ah_mul_bh = big_int.mul(ah, bh)local ah_mul_bl = big_int.mul(ah, bl)local al_mul_bh = big_int.mul(al, bh)local al_mul_bl = big_int.mul(al, bl)local result = ah_mul_bh .. string.rep("0", a_digit_num_l + b_digit_num_l)result = big_int.add(result, ah_mul_bl .. string.rep("0", a_digit_num_l))result = big_int.add(result, al_mul_bh .. string.rep("0", b_digit_num_l))result = big_int.add(result, al_mul_bl)return resultend
end
完整代码(包括一些测试)在这里.