模意义下整数运算 (ModInt) 模板详解

模意义下整数运算 (ModInt) 模板详解

1. 简介

ModInt 是一个用于处理模意义下整数运算的模板类,它可以自动处理模运算,避免溢出,并提供了方便的运算符重载。

主要特点:

  • 支持静态模数和动态模数
  • 自动处理模运算和溢出
  • 支持基本四则运算
  • 提供快速幂和逆元计算
  • 包含 Barrett 优化

2. 实现原理

2.1 基本概念

  1. 模运算

    • 所有运算都在模 P 意义下进行
    • 需要保证结果在 [0, P) 范围内
  2. Barrett 优化

    • 用于优化模乘法运算
    • 避免使用除法运算

2.2 核心策略

  1. 模板设计

    • 使用模板参数指定模数
    • 支持不同整数类型
  2. 运算优化

    • 使用 Barrett 算法优化乘法
    • 快速幂优化幂运算
    • 费马小定理求逆元

3. 模板代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
using u32 = unsigned int;
using u64 = unsigned long long;
using u128 = unsigned __int128;

template<class T>
constexpr T power(T a, u64 b, T res = 1) {
for (; b != 0; b /= 2, a *= a) {
if (b & 1) {
res *= a;
}
}
return res;
}

template<u32 P>
constexpr u32 mulMod(u32 a, u32 b) {
return u64(a) * b % P;
}

template<u64 P>
constexpr u64 mulMod(u64 a, u64 b) {
u64 res = a * b - u64(1.L * a * b / P - 0.5L) * P;
res %= P;
return res;
}

constexpr i64 safeMod(i64 x, i64 m) {
x %= m;
if (x < 0) {
x += m;
}
return x;
}

constexpr std::pair<i64, i64> invGcd(i64 a, i64 b) {
a = safeMod(a, b);
if (a == 0) {
return {b, 0};
}

i64 s = b, t = a;
i64 m0 = 0, m1 = 1;

while (t) {
i64 u = s / t;
s -= t * u;
m0 -= m1 * u;

std::swap(s, t);
std::swap(m0, m1);
}

if (m0 < 0) {
m0 += b / s;
}

return {s, m0};
}

template<std::unsigned_integral U, U P>
struct ModIntBase {
constexpr ModIntBase() : x(0) {}
template<std::unsigned_integral T>
constexpr ModIntBase(T x_) : x(x_ % mod()) {}
template<std::signed_integral T>
constexpr ModIntBase(T x_) {
using S = std::make_signed_t<U>;
S v = x_ % S(mod());
if (v < 0) {
v += mod();
}
x = v;
}

constexpr static U mod() {
return P;
}

constexpr U val() const {
return x;
}

constexpr ModIntBase operator-() const {
ModIntBase res;
res.x = (x == 0 ? 0 : mod() - x);
return res;
}

constexpr ModIntBase inv() const {
return power(*this, mod() - 2);
}

constexpr ModIntBase &operator*=(const ModIntBase &rhs) & {
x = mulMod<mod()>(x, rhs.val());
return *this;
}
constexpr ModIntBase &operator+=(const ModIntBase &rhs) & {
x += rhs.val();
if (x >= mod()) {
x -= mod();
}
return *this;
}
constexpr ModIntBase &operator-=(const ModIntBase &rhs) & {
x -= rhs.val();
if (x >= mod()) {
x += mod();
}
return *this;
}
constexpr ModIntBase &operator/=(const ModIntBase &rhs) & {
return *this *= rhs.inv();
}

friend constexpr ModIntBase operator*(ModIntBase lhs, const ModIntBase &rhs) {
lhs *= rhs;
return lhs;
}
friend constexpr ModIntBase operator+(ModIntBase lhs, const ModIntBase &rhs) {
lhs += rhs;
return lhs;
}
friend constexpr ModIntBase operator-(ModIntBase lhs, const ModIntBase &rhs) {
lhs -= rhs;
return lhs;
}
friend constexpr ModIntBase operator/(ModIntBase lhs, const ModIntBase &rhs) {
lhs /= rhs;
return lhs;
}

friend constexpr std::istream &operator>>(std::istream &is, ModIntBase &a) {
i64 i;
is >> i;
a = i;
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const ModIntBase &a) {
return os << a.val();
}

friend constexpr std::strong_ordering operator<=>(ModIntBase lhs, ModIntBase rhs) {
return lhs.val() <=> rhs.val();
}

private:
U x;
};

template<u32 P>
using ModInt = ModIntBase<u32, P>;
template<u64 P>
using ModInt64 = ModIntBase<u64, P>;

struct Barrett {
Barrett(u32 m_) : m(m_), im((u64)(-1) / m_ + 1) {}

constexpr u32 mod() const {
return m;
}

constexpr u32 mul(u32 a, u32 b) const {
u64 z = a;
z *= b;

u64 x = u64((u128(z) * im) >> 64);

u32 v = u32(z - x * m);
if (m <= v) {
v += m;
}
return v;
}

private:
u32 m;
u64 im;
};

template<u32 Id>
struct DynModInt {
constexpr DynModInt() : x(0) {}
template<std::unsigned_integral T>
constexpr DynModInt(T x_) : x(x_ % mod()) {}
template<std::signed_integral T>
constexpr DynModInt(T x_) {
int v = x_ % int(mod());
if (v < 0) {
v += mod();
}
x = v;
}

constexpr static void setMod(u32 m) {
bt = m;
}

static u32 mod() {
return bt.mod();
}

constexpr u32 val() const {
return x;
}

constexpr DynModInt operator-() const {
DynModInt res;
res.x = (x == 0 ? 0 : mod() - x);
return res;
}

constexpr DynModInt inv() const {
auto v = invGcd(x, mod());
assert(v.first == 1);
return v.second;
}

constexpr DynModInt &operator*=(const DynModInt &rhs) & {
x = bt.mul(x, rhs.val());
return *this;
}
constexpr DynModInt &operator+=(const DynModInt &rhs) & {
x += rhs.val();
if (x >= mod()) {
x -= mod();
}
return *this;
}
constexpr DynModInt &operator-=(const DynModInt &rhs) & {
x -= rhs.val();
if (x >= mod()) {
x += mod();
}
return *this;
}
constexpr DynModInt &operator/=(const DynModInt &rhs) & {
return *this *= rhs.inv();
}

friend constexpr DynModInt operator*(DynModInt lhs, const DynModInt &rhs) {
lhs *= rhs;
return lhs;
}
friend constexpr DynModInt operator+(DynModInt lhs, const DynModInt &rhs) {
lhs += rhs;
return lhs;
}
friend constexpr DynModInt operator-(DynModInt lhs, const DynModInt &rhs) {
lhs -= rhs;
return lhs;
}
friend constexpr DynModInt operator/(DynModInt lhs, const DynModInt &rhs) {
lhs /= rhs;
return lhs;
}

friend constexpr std::istream &operator>>(std::istream &is, DynModInt &a) {
i64 i;
is >> i;
a = i;
return is;
}
friend constexpr std::ostream &operator<<(std::ostream &os, const DynModInt &a) {
return os << a.val();
}

friend constexpr std::strong_ordering operator<=>(DynModInt lhs, DynModInt rhs) {
return lhs.val() <=> rhs.val();
}

private:
u32 x;
static Barrett bt;
};

template<u32 Id>
Barrett DynModInt<Id>::bt = 998244353;

using Z = ModInt<998244353>;

4. 函数说明

4.1 基础函数

  • power(): 快速幂运算
  • mulMod(): 模乘法运算
  • safeMod(): 安全的取模运算
  • invGcd(): 扩展欧几里得算法求逆元

4.2 ModInt 类

  • 构造函数:支持各种整数类型
  • 基本运算:加减乘除和取负
  • 特殊运算:求逆元、快速幂
  • 比较运算:完整的比较运算符

5. 时间复杂度分析

  • 基本运算:O(1)
  • 快速幂:O(log n)
  • 求逆元:O(log m),m 为模数

6. 应用场景

  1. 组合数学计算
  2. 多项式运算
  3. 动态规划
  4. 矩阵运算

7. 使用示例

7.1 静态模数

1
2
3
ModInt<998244353> a = 1000000007;
ModInt<998244353> b = 2000000014;
auto c = a * b; // 自动处理模运算

7.2 动态模数

1
2
3
4
5
using Mint = DynModInt<0>;
Mint::setMod(1000000007);
Mint a = 1000000;
Mint b = 2000000;
auto c = a + b; // 在新模数下运算

8. 注意事项

  1. 模数的选择

    • 静态模数必须是编译期常量
    • 动态模数可以在运行时修改
  2. 性能考虑

    • 静态模数性能更好
    • Barrett 优化用于提高乘法性能
  3. 数值范围

    • 注意中间结果不要溢出
    • 选择合适的整数类型

9. 总结

ModInt 模板提供了一个高效、安全的模运算框架,支持静态和动态模数,并通过 Barrett 优化提高了性能。模板的设计既保证了易用性,又维持了高效性,是处理模运算问题的有力工具。