0%

利用快速幂模运算加速RSA加密解密过程

简介

之前实现RSA加密算法时,计算幂成了程序的瓶颈,前段时间了解了快速幂以及快速幂模运算,这种算法可以用于加速RSA加密解密过程中的幂模计算过程。

快速幂

基本原理

用精确的数学符号表示的话,有下面两种写法:

$$ a ^ b = {\begin{cases}(a ^ {\frac b 2}) ^ 2, &if\ b\ is\ even\\a \cdot (a ^ {\frac {b-1} 2}) ^ 2, &if\ b\ is\ odd\end{cases}}$$

$$ a ^ b = {\begin{cases}(a ^ {\lfloor {\frac b 2} \rfloor}) ^ 2, &if\ b\ is\ even\\a \cdot (a ^ {\lfloor {\frac b 2} \rfloor}) ^ 2, &if\ b\ is\ odd\end{cases}}$$

上述公式显然成立,故计算幂可用上述递推式进行。

而在编程语言中,整除一般就相当于向下取整,故上述递推式可以表示成下面的伪代码:

1
2
3
4
5
6
7
func pow(a, b):
if b is 0:
return 1
if b is even:
return pow(a, b/2) ^ 2
if b is odd:
return pow(a, b/2) ^ 2 * a

时间复杂度

对于 $a^b$,用原始的求幂算法,需要计算 $b$ 次乘法,用上述递推式大约需要运算 $log\ b$ 次

代码实现

下面的代码实现了朴素求幂算法和快速幂算法,并对比了他们的运算时间。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import time


def calc_time(func):
def inner_func(*args, **kwargs):
start_time = time.time()
func(*args, **kwargs)
end_time = time.time()
print(f"Time used: {end_time - start_time}")
return inner_func


@calc_time
def prime_pow(base, power):
res = 1
for _ in range(power):
res *= base
# print(f'The last five digits of the result {base} ^ {power}: {str(res)[-5:]}')


def quick_pow(base, power):
if power == 0:
return 1
res = quick_pow(base, power // 2)
if power % 2 == 0:
return res * res
else:
return res * res * base


@calc_time
def quick_pow_wrapper(base, power):
res = quick_pow(base, power)
# print(f'The last five digits of the result of {base} ^ {power}: {str(res)[-5:]}')


prime_pow(2, 1000000) # Time used: 15.331219911575317
quick_pow_wrapper(2, 1000000) # Time used: 0.003216981887817383
quick_pow_wrapper(13789, 722341) # Time used: 0.003216981887817383
prime_pow(13789, 722341) # long long long long no result

快速幂模

基本原理

求模运算有这样一条性质:

利用这一性质,结合快速幂递推式可得:

$$ a ^ b \% m = {\begin{cases}(a ^ {\lfloor {\frac b 2} \rfloor} \% m) ^ 2 \%m, &if\ b\ is\ even\\(a \%m \cdot (a ^ {\lfloor {\frac b 2} \rfloor} \%m) ^ 2) \%m, &if\ b\ is\ odd\end{cases}}$$

表示成伪代码如下所示:

1
2
3
4
5
6
7
func pow_mod(a, b, m):
if b is 0:
return 1
if b is even:
return pow_mod(a, b/2, m) ^ 2 % m
if b is odd:
return (a % m * pow_mod(a, b/2, m) ^ 2) % m

代码实现

1
2
3
4
5
6
7
8
9
10
def pow_mod(a, b, m):
if b == 0:
return 1
res = pow_mod(a, b // 2, m)
if b % 2 == 0:
return res * res % m
else:
return res * res * (a % m) % m

print(pow_mod(2147483647, 200, 1337))

加速RSA加密解密过程

RSA加密和解密需要计算大数幂的模,用原始计算方法难以进行,可以用快速幂模运算加速运算过程。

下面的代码是对 2021-05-23-RSA加密算法原理 中RSA密钥生成、加密、解密的DEMO的改进。

具体代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from typing import Tuple


class RSAKey:
class RSAParams:
def __init__(self, p, q, N, phiN, e, d) -> None:
self.p = p
self.q = q
self.N = N
self.phiN = phiN
self.e = e
self.d = d

def __str__(self) -> str:
return f"""{{
p: {self.p}\n
q: {self.q}\n
N: {self.N}\n
phiN: {self.phiN}\n
e: {self.e}\n
d: {self.d}
}}"""

def __init__(self) -> None:
self.params = self.__generateRSAParams()

def getPublicKey(self) -> Tuple[int, int]:
"""get public key

Returns:
Tuple[int, int]: (N, e)
"""
return (self.params.N, self.params.e)

def getPrivateKey(self) -> Tuple[int, int]:
"""get private key

Returns:
Tuple[int, int]: (N, d)
"""
return (self.params.N, self.params.d)

def __generateRSAParams(self) -> RSAParams:
"""Generate parameters for RSA

Returns:
RSAParams: instance of RSAParams
"""
# 产生两个大素数
from Crypto.Util import number
p: int = number.getPrime(476)
q: int = number.getPrime(476)

# p 与 q 不能相等
while q == p:
q: int = number.getPrime(10)

# N = p x q
N: int = p * q

# phi(N)
phiN: int = (p - 1) * (q - 1)

# 取公钥参数 e,e 应小于 phi(N) 且与 phi(N) 互质
# 一种简单的思路是找到一个质数 e,只要 phi(N) 不是它的倍数即可
e: int = number.getPrime(16)
while phiN % e == 0:
e: int = number.getPrime(16)

# 使用扩展欧几里得算法求解 e 的模逆元 d
_, d, _ = self.__exgcd(e, phiN)

# 如果计算得到的 d 是负数,则加上 phi(N) 将其转为正数,仍然与 phi(N) 保持同余关系
if d < 0:
d = d + phiN

return self.RSAParams(p, q, N, phiN, e, d)

def __exgcd(self, a, b):
"""扩展欧几里得算法

Args:
a (int): a
b (int): b

Returns:
int: (gcd, x, y)
"""
ri: int = a
rj: int = b
si: int = 1
sj: int = 0
ti: int = 0
tj: int = 1

while rj != 0:
qi = ri // rj

rtemp = rj
rj = ri - qi * rj
ri = rtemp

stemp = sj
sj = si - qi * sj
si = stemp

ttemp = tj
tj = ti - qi * tj
ti = ttemp

return ri, si, ti


def pow_mod(a: int, b: int, m: int) -> int:
"""fast power module

Args:
a : a ^ b % m
b : a ^ b % m
m : a ^ b % m

Returns:
int: a ^ b % m
"""
if b == 0:
return 1
res = pow_mod(a, b // 2, m)
if b % 2 == 0:
return res * res % m
else:
return res * res * (a % m) % m


def encrypt(m, publicKey) -> int:
"""encrypt message using public key

Args:
m (int): origin message m
publicKey (Tuple[int, int]): (N, e)

Returns:
int: encrypted message c
"""
N, e = publicKey
c = pow_mod(m, e, N)
return c


def decrypt(c, privateKey) -> int:
"""decrypt message using private key

Args:
c (int): encrypted message c
privateKey (Tuple[int, int]): (N, d)

Returns:
int: origin message m
"""
N, d = privateKey
# m = c ** d % N
m = pow_mod(c, d, N)
return m


def main():
rsakey = RSAKey()
publicKey = rsakey.getPublicKey()
privateKey = rsakey.getPrivateKey()

print(f"publicKey: {publicKey}")
print(f"privateKey: {privateKey}")

message = 123456789
encryptedMessage = encrypt(message, publicKey)
decryptedMessage = decrypt(encryptedMessage, privateKey)

print(f"origin message: {message}")
print(f"encrypted message: {encryptedMessage}")
print(f"decrypted message: {decryptedMessage}")


if __name__ == '__main__':
main()