pymodel

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import sys
sys.setrecursionlimit(100000)
# Fast line reader; intentionally replaces the builtin input() for this script.
input = lambda: sys.stdin.readline().strip()
# Optional helpers: uncomment the ones a given problem needs.
# write=lambda x:sys.stdout.write(str(x)+'\n')
# from decimal import Decimal
# from datetime import datetime,timedelta
# from random import randint
# from copy import deepcopy
from collections import *
# from heapq import heapify,heappush,heappop
# from bisect import bisect_left,bisect,insort
# from math import inf,sqrt,gcd,pow,ceil,floor,log,log2,log10,pi,sin,cos,tan,asin,acos,atan
# from functools import cmp_to_key,reduce
# from operator import or_,xor,add,mul
# from itertools import permutations,combinations,accumulate
sint = lambda: int(input())                      # one integer on a line
mint = lambda: map(int, input().split())         # iterator of integers on a line
lint = lambda: list(map(int, input().split()))   # list of integers on a line


def solve():
    """Solve a single test case (template placeholder: fill in per problem)."""
    # An explicit body keeps the template importable and runnable as-is.
    return None


if __name__ == '__main__':
    # For multi-test problems, replace the single call with:
    # t = int(input())
    # for _ in range(t):
    #     solve()
    solve()

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class DSU:
    """Disjoint-set union (union-find) over the integers ``0 .. n - 1``.

    Sets are merged by size and paths are compressed on lookup.

    Attributes:
        parent: parent[i] is the parent of i; a root is its own parent.
        size: size[r] is the number of elements in the set rooted at r
            (only meaningful for roots).
        n: number of elements the structure was created with.
        setCount: current number of disjoint sets.
    """

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        self.n = n
        self.setCount = n

    def find(self, x):
        """Return the representative of the set containing ``x``."""
        # Iterative two-pass lookup: no dependence on the recursion limit.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the walked path at the root.
        while self.parent[x] != root:
            self.parent[x], x = root, self.parent[x]
        return root

    def union(self, x, y):
        """Merge the sets of ``x`` and ``y``.

        Returns:
            True if two distinct sets were merged, False if ``x`` and ``y``
            were already in the same set.
        """
        x, y = self.find(x), self.find(y)
        if x == y:
            return False
        # Attach the smaller tree under the larger one.
        if self.size[x] < self.size[y]:
            x, y = y, x
        self.parent[y] = x
        self.size[x] += self.size[y]
        self.setCount -= 1
        return True

    def connected(self, x, y):
        """Return True if ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)

组合数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Factorial:
    """Precomputed factorials and inverse factorials modulo ``mod``.

    The inverse of ``N!`` is obtained with ``pow(x, mod - 2, mod)``, which is a
    modular inverse only when ``mod`` is prime and ``N < mod``.

    Attributes:
        mod: the modulus.
        f: f[i] == i! % mod for 0 <= i <= N.
        g: g[i] == modular inverse of i! for 0 <= i <= N.
    """

    def __init__(self, N, mod) -> None:
        N += 1  # tables cover 0..N inclusive
        self.mod = mod
        self.f = [1] * N
        self.g = [1] * N
        for i in range(1, N):
            self.f[i] = self.f[i - 1] * i % mod
        # Invert only the largest factorial, then walk down:
        # inverse(i!) == inverse((i + 1)!) * (i + 1).
        self.g[-1] = pow(self.f[-1], mod - 2, mod)
        for i in range(N - 2, -1, -1):
            self.g[i] = self.g[i + 1] * (i + 1) % mod

    def fac(self, n):
        """Return n! modulo ``mod``."""
        return self.f[n]

    def fac_inv(self, n):
        """Return the modular inverse of n!."""
        return self.g[n]

    def comb(self, n, m):
        """Return C(n, m) modulo ``mod``; 0 when the arguments are out of range."""
        if n < m or m < 0 or n < 0:
            return 0
        return self.f[n] * self.g[m] % self.mod * self.g[n - m] % self.mod

    def permu(self, n, m):
        """Return P(n, m) = n! / (n - m)! modulo ``mod``; 0 when out of range."""
        if n < m or m < 0 or n < 0:
            return 0
        return self.f[n] * self.g[n - m] % self.mod

    def catalan(self, n):
        """Return the n-th Catalan number modulo ``mod`` (needs ``2 * n <= N``)."""
        return (self.comb(2 * n, n) - self.comb(2 * n, n - 1)) % self.mod

    def inv(self, n):
        """Return the modular inverse of ``n`` as (n - 1)! * inverse(n!).

        Raises:
            ValueError: if ``n < 1``; negative list indexing would otherwise
                return an unrelated value.
        """
        if n < 1:
            raise ValueError('inv() requires n >= 1')
        return self.f[n - 1] * self.g[n] % self.mod

质数筛

1
2
3
4
5
6
7
8
9
10
11
12
13
pri = []        # primes found by the last pre() call, in increasing order
not_prime = []  # not_prime[k] is True once k is known to be composite


def pre(n):
    """Run a linear (Euler) sieve and fill ``pri`` with every prime <= n.

    ``not_prime`` is grown in place so that indices ``0 .. n`` are valid; it is
    never shrunk. ``pri`` is cleared first, so repeated calls do not
    accumulate duplicates.
    """
    pri.clear()
    missing = n + 1 - len(not_prime)
    if missing > 0:
        not_prime.extend([False] * missing)
    for i in range(2, n + 1):
        if not not_prime[i]:
            pri.append(i)
        for pri_j in pri:
            if i * pri_j > n:
                break
            not_prime[i * pri_j] = True
            # Stop at the smallest prime factor of i so each composite is
            # crossed out exactly once.
            if i % pri_j == 0:
                break

分解质因数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def divide(n):
    """Factorise ``n`` by trial division.

    Returns:
        A list of ``(prime, exponent)`` pairs in increasing order of prime.
        The list is empty for ``n <= 1``.
    """
    ans = []
    i = 2
    # Only candidates with i * i <= n can be the smaller factor of a pair.
    while i <= n // i:
        if n % i == 0:
            cnt = 0
            while n % i == 0:
                cnt += 1
                n //= i
            ans.append((i, cnt))
        # After 2, even candidates can never divide what is left of n.
        i += 1 if i == 2 else 2
    # Whatever remains above 1 is a single prime larger than sqrt(original n).
    if n > 1:
        ans.append((n, 1))
    return ans

约数之和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Sum of divisors of the product of `a` input numbers, modulo 1e9 + 7.
mod = int(1e9 + 7)
primes = {}  # prime -> total exponent over all numbers read
a = int(input())
while a:
    a -= 1
    n = int(input())
    i = 2
    while i <= n // i:
        while n % i == 0:
            n //= i
            primes[i] = primes.get(i, 0) + 1
        i += 1
    # Leftover factor above 1 is a prime larger than the square root.
    if n > 1:
        primes[n] = primes.get(n, 0) + 1

res = 1
for i, val in primes.items():
    # Horner evaluation of 1 + i + i^2 + ... + i^val.
    t = 1
    for _ in range(val):
        t = (t * i + 1) % mod
    res = res * t % mod
print(res)

约数个数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Number of divisors of the product of `a` input numbers, modulo 1e9 + 7.
# The modulus is an int so the running product stays in exact integer
# arithmetic (a float modulus turns every intermediate result into a float).
mod = 10 ** 9 + 7

primes = {}  # prime -> total exponent over all numbers read
a = int(input())
while a:
    a -= 1
    n = int(input())
    i = 2
    while i <= n // i:
        while n % i == 0:
            n //= i
            primes[i] = primes.get(i, 0) + 1
        i += 1
    # Leftover factor above 1 is a prime larger than the square root.
    if n > 1:
        primes[n] = primes.get(n, 0) + 1

# d(p1^e1 * ... * pk^ek) = (e1 + 1) * ... * (ek + 1)
res = 1
for i in primes.values():
    res = res * (i + 1) % mod
print(res)

欧拉函数

alt text

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# For each of n input numbers, print Euler's totient computed by trial division.
n = int(input())
for _ in range(n):
    a = int(input())
    res = a
    j = 2
    while j * j <= a:
        if a % j == 0:
            # j divides res here, so this equals res * (j - 1) // j.
            res -= res // j
            while a % j == 0:
                a //= j
        j += 1
    # A remaining factor above 1 is one more distinct prime.
    if a > 1:
        res -= res // a
    print(res)

筛法求欧拉函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
N = 1000010
primes = [0] * N    # primes[0 .. cnt - 1] hold the primes found so far
phi = [0] * N       # phi[k] is Euler's totient of k once get_eulers has run
st = [False] * N    # st[k] is True when k has been crossed out as composite


def get_eulers(n):
    """Fill ``phi[1 .. n]`` with Euler's totient using a linear sieve.

    The module-level tables are grown in place when ``n`` does not fit in
    them, so ``n`` is not limited to the initial capacity ``N``.
    """
    for table, fill in ((primes, 0), (phi, 0), (st, False)):
        if len(table) <= n:
            table.extend([fill] * (n + 1 - len(table)))

    phi[1] = 1
    cnt = 0
    for i in range(2, n + 1):
        if not st[i]:
            primes[cnt] = i
            cnt += 1
            phi[i] = i - 1
        j = 0
        # The loop always exits before j reaches cnt: it breaks at the
        # smallest prime factor of i, which is already in the table.
        while primes[j] <= n // i:
            st[primes[j] * i] = True
            if i % primes[j] == 0:
                # primes[j] already divides i: phi(p * i) = p * phi(i).
                phi[primes[j] * i] = phi[i] * primes[j]
                break
            # primes[j] is coprime to i: phi(p * i) = (p - 1) * phi(i).
            phi[primes[j] * i] = phi[i] * (primes[j] - 1)
            j += 1

扩展欧几里得

alt text

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def extend_gcd(a, b, x=0, y=0):
    """Extended Euclidean algorithm.

    Args:
        a, b: the two integers.
        x, y: ignored; accepted only so existing four-argument calls keep
            working (their incoming values never influence the result).

    Returns:
        A tuple ``(d, x, y)`` with ``d = gcd(a, b)`` and ``a * x + b * y == d``.
    """
    if not b:
        return a, 1, 0

    # The recursive call solves b * x1 + (a % b) * y1 == d; substituting
    # a % b = a - (a // b) * b gives x = y1 and y = x1 - (a // b) * y1.
    d, y, x = extend_gcd(b, a % b)
    y -= a // b * x
    return d, x, y

# For each of n pairs (a, b), print one integer solution (x, y) of
# a * x + b * y = gcd(a, b).
n = int(input())
while n:
    n -= 1
    a, b = (int(tok) for tok in input().split())
    d, x, y = extend_gcd(a, b, 0, 0)
    print(x, y)

线性同余方程

20240518200819

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def extend_gcd(a, b, x=0, y=0):
    """Extended Euclidean algorithm.

    Args:
        a, b: the two integers.
        x, y: ignored; accepted only so existing four-argument calls keep
            working (their incoming values never influence the result).

    Returns:
        A tuple ``(d, x, y)`` with ``d = gcd(a, b)`` and ``a * x + b * y == d``.
    """
    if not b:
        return a, 1, 0

    # The recursive call solves b * x1 + (a % b) * y1 == d; substituting
    # a % b = a - (a // b) * b gives x = y1 and y = x1 - (a // b) * y1.
    d, y, x = extend_gcd(b, a % b)
    y -= a // b * x
    return d, x, y

# For each of n triples (a, b, m), print some x with a * x ≡ b (mod m),
# or "impossible" when gcd(a, m) does not divide b.
n = int(input())
while n:
    n -= 1
    a, b, m = (int(tok) for tok in input().split())
    d, x, y = extend_gcd(a, m, 0, 0)
    if b % d == 0:
        # Scale the Bezout coefficient of a by b / d and reduce modulo m.
        print(x * (b // d) % m)
    else:
        print("impossible")

求组合数

1
2
3
4
5
import math
def C(n, m):
    """Return the exact binomial coefficient "n choose m".

    Returns 0 when ``m < 0`` or ``m > n`` (the same convention as
    ``Factorial.comb`` in this file) instead of raising from a negative
    factorial.
    """
    if m < 0 or m > n:
        return 0
    # Use the smaller of m and n - m so the loop is as short as possible.
    m = min(m, n - m)
    result = 1
    for i in range(1, m + 1):
        # After this step result == C(n - m + i, i), so the division is exact.
        result = result * (n - m + i) // i
    return result
# Read n and m from one line and print C(n, m).
n, m = (int(tok) for tok in input().split())
print(C(n, m))

数据结构

树状数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class FenWick:
    """Fenwick tree (binary indexed tree) over positions ``0 .. n - 1``.

    Public indices are 0-based; ranges are closed on both ends. Internally
    the tree is stored 1-based in ``tr[1 .. n]``.
    """

    def __init__(self, n: int):
        self.n = n
        self.tr = [0] * (n + 1)

    def sum(self, i: int):
        """Return the sum of positions ``0 .. i`` inclusive (0 for ``i < 0``)."""
        i += 1
        s = 0
        while i >= 1:
            s += self.tr[i]
            i &= i - 1  # drop the lowest set bit
        return s

    def rangeSum(self, l: int, r: int):
        """Return the sum of positions ``l .. r`` inclusive."""
        return self.sum(r) - self.sum(l - 1)

    def add(self, i: int, val: int):
        """Add ``val`` to position ``i``.

        Positions ``i >= n`` are ignored, as before. A negative ``i`` raises
        ``IndexError``: it would otherwise never terminate, because the
        internal index stalls at 0.
        """
        if i < 0:
            raise IndexError('FenWick index must be non-negative')
        i += 1
        while i <= self.n:
            self.tr[i] += val
            i += i & -i  # move to the next node covering this position

线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
from collections import defaultdict
# region fastio
def input():
    """Read one line from stdin with trailing whitespace removed."""
    return sys.stdin.readline().rstrip()


def sint():
    """Read a line holding a single integer."""
    return int(input())


def mint():
    """Read a line of whitespace-separated integers as a lazy iterator."""
    return map(int, input().split())


def ints():
    """Read a line of whitespace-separated integers as a list."""
    return list(map(int, input().split()))
class segtree:
    """Segment tree over positions ``1 .. n`` with point add, range sum and range max.

    Nodes are numbered from 1 (the root); node ``k`` has children ``2 * k``
    and ``2 * k + 1``. Every position starts at 0. Callers pass the node
    interval explicitly, e.g. ``tree.query(L, R, 1, n, 1)``.

    Attributes:
        n: number of positions.
        m: stored but not used by any method (it was the initial fill value of
            a build step that is disabled in this version).
        max: per-node maximum of the node's interval.
        sum: per-node sum of the node's interval.
    """

    def __init__(self, n: int, m: int):
        self.n = n
        self.m = m
        self.max = [0] * (4 * n)  # range maximum per node
        self.sum = [0] * (4 * n)  # range sum per node

    def pushUp(self, root):
        """Recompute node ``root`` from its two children."""
        self.max[root] = max(self.max[2 * root], self.max[2 * root + 1])
        self.sum[root] = self.sum[2 * root] + self.sum[2 * root + 1]

    def add(self, i: int, val: int, left: int, right: int, root: int) -> None:
        """Point update: add ``val`` to position ``i``.

        ``root`` is the current node and ``[left, right]`` the interval it
        covers; call with ``(i, val, 1, n, 1)``.
        """
        if left == right:  # reached the leaf for position i
            self.max[root] += val
            self.sum[root] += val
            return
        mid = (left + right) // 2
        if i <= mid:
            self.add(i, val, left, mid, 2 * root)
        else:
            self.add(i, val, mid + 1, right, 2 * root + 1)
        self.pushUp(root)  # a child changed, so refresh this node

    def query(self, L: int, R: int, left: int, right: int, root: int):
        """Return the sum of positions ``L .. R``; call with ``(L, R, 1, n, 1)``.

        An empty range (``L > R``) or a range that does not intersect
        ``[left, right]`` yields 0 instead of recursing without end.
        """
        if L > R or R < left or L > right:
            return 0
        if L <= left and right <= R:
            return self.sum[root]
        mid = (left + right) // 2
        range_sum = 0
        if L <= mid:  # overlaps the left child
            range_sum += self.query(L, R, left, mid, 2 * root)
        if R > mid:  # overlaps the right child
            range_sum += self.query(L, R, mid + 1, right, 2 * root + 1)
        return range_sum

    def query_max(self, L: int, R: int, left: int, right: int, root: int):
        """Return the maximum of positions ``L .. R``; call with ``(L, R, 1, n, 1)``.

        The result is built only from nodes inside the query range, so it is
        correct when stored values are negative. An empty range (``L > R``) or
        one that does not intersect ``[left, right]`` yields 0, the value of an
        untouched position.
        """
        if L > R or R < left or L > right:
            return 0
        if L <= left and right <= R:
            return self.max[root]
        mid = (left + right) // 2
        # Descend only into children that intersect [L, R].
        if R <= mid:
            return self.query_max(L, R, left, mid, 2 * root)
        if L > mid:
            return self.query_max(L, R, mid + 1, right, 2 * root + 1)
        return max(self.query_max(L, R, left, mid, 2 * root),
                   self.query_max(L, R, mid + 1, right, 2 * root + 1))

def solve():
    """Process m operations on a growing sequence, read from stdin.

    The first line holds ``m`` and ``p``. Each following line is an operation
    letter and an integer: ``A x`` appends ``(last + x) % p`` to the sequence;
    any other letter with ``x`` prints the maximum of the last ``x`` elements
    and remembers it as ``last``.
    """
    m, p = mint()
    tree = segtree(m, 0)  # at most m appends, so m positions are enough
    last = 0
    count = 0             # number of elements appended so far
    for _ in range(m):
        op, arg = input().split()
        arg = int(arg)
        if op != 'A':
            last = tree.query_max(count - arg + 1, count, 1, m, 1)
            print(last)
            continue
        count += 1
        tree.add(count, (last + arg) % p, 1, m, 1)


if __name__ == '__main__':
    solve()

SortedList

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234

class SortedList:
    """Sorted multiset stored as a list of bounded, sorted buckets.

    ``_lists`` holds the buckets, ``_list_lens`` their lengths and ``_mins``
    their first elements. A Fenwick tree over the bucket lengths
    (``_fen_tree``) translates between global indices and
    ``(bucket, offset)`` pairs; it is rebuilt lazily whenever ``_rebuild``
    is set.
    """

    def __init__(self, iterable=(), _load=200):
        """Initialize sorted list instance."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True

    def _fen_build(self):
        """Build a fenwick tree instance."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False

    def _fen_update(self, index, value):
        """Update `fen_tree[index] += value`."""
        # Skipped while a rebuild is pending: the tree is recomputed anyway.
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1

    def _fen_query(self, end):
        """Return `sum(_fen_tree[:end])`."""
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x

    def _fen_findkth(self, k):
        """Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`).

        Raises IndexError when `k` is not a valid position, before any state
        is touched.
        """
        if not 0 <= k < self._len:
            raise IndexError('list index out of range')

        _list_lens = self._list_lens
        # Fast paths for the first and last bucket avoid touching the tree.
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        idx = -1
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k

    def _delete(self, pos, idx):
        """Delete value at the given `(pos, idx)`."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1

        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            # The bucket became empty: drop it and rebuild the tree later.
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True

    def _loc_left(self, value):
        """Return an index pair that corresponds to the first position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi

        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def _loc_right(self, value):
        """Return an index pair that corresponds to the last position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def add(self, value):
        """Add `value` to sorted list."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            # Split a bucket that has grown past twice the load factor.
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
                self._rebuild = True
        else:
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True

    def discard(self, value):
        """Remove `value` from sorted list if it is a member."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)

    def remove(self, value):
        """Remove `value` from sorted list; `value` must be a member."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))

    def pop(self, index=-1):
        """Remove and return value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value

    def bisect_left(self, value):
        """Return the first index to insert `value` in the sorted list."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx

    def bisect_right(self, value):
        """Return the last index to insert `value` in the sorted list."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx

    def count(self, value):
        """Return number of occurrences of `value` in the sorted list."""
        return self.bisect_right(value) - self.bisect_left(value)

    def __len__(self):
        """Return the size of the sorted list."""
        return self._len

    def __getitem__(self, index):
        """Lookup value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        return self._lists[pos][idx]

    def __delitem__(self, index):
        """Remove value at `index` from sorted list."""
        pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
        self._delete(pos, idx)

    def __contains__(self, value):
        """Return true if `value` is an element of the sorted list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False

    def __iter__(self):
        """Return an iterator over the sorted list."""
        return (value for _list in self._lists for value in _list)

    def __reversed__(self):
        """Return a reverse iterator over the sorted list."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))

    def __repr__(self):
        """Return string representation of sorted list."""
        return 'SortedList({0})'.format(list(self))

字符串哈希

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class StringHash:
    """Polynomial rolling hash of a string with O(1) substring hashes.

    Args:
        s: the string to hash (each character contributes ``ord(ch)``).
        base: polynomial base; defaults to the previously hard-coded 131.
        mod: modulus; defaults to the previously hard-coded ``10 ** 13 + 7``.

    Attributes:
        h: h[i] is the hash of the prefix ``s[:i]``.
        p: p[i] is ``base ** i % mod``.
    """

    def __init__(self, s, base=131, mod=10 ** 13 + 7):
        n = len(s)
        self.base = base
        self.mod = mod
        self.h = h = [0] * (n + 1)
        self.p = p = [1] * (n + 1)
        for i in range(1, n + 1):
            p[i] = p[i - 1] * base % mod
            h[i] = (h[i - 1] * base + ord(s[i - 1])) % mod

    def get_hash(self, l, r):
        """Return the hash of ``s[l:r]`` (``l`` inclusive, ``r`` exclusive)."""
        # Remove the prefix s[:l] after shifting it by the substring length.
        res = self.h[r] - self.h[l] * self.p[r - l]
        return res % self.mod

二维差分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Two-dimensional difference array over the 1-based n x m matrix `a`.
# Step 1: derive the difference matrix b from a.
for i in range(1, n + 1):
    cur, above = a[i], a[i - 1]
    for j in range(1, m + 1):
        b[i][j] = cur[j] - above[j] - cur[j - 1] + above[j - 1]

# Step 2: apply q rectangle updates "add c to a[x1..x2][y1..y2]" in O(1) each.
for _ in range(q):
    x1, y1, x2, y2, c = mint()
    b[x1][y1] += c
    b[x1][y2 + 1] -= c
    b[x2 + 1][y1] -= c
    b[x2 + 1][y2 + 1] += c

# Step 3: prefix-sum b back into a and print the matrix row by row.
for i in range(1, n + 1):
    cur, above = a[i], a[i - 1]
    for j in range(1, m + 1):
        cur[j] = b[i][j] + above[j] + cur[j - 1] - above[j - 1]
        print(cur[j], end = ' ')
    print()

图论

spfa

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def spfa(u, n, g) -> int:
    """Shortest distance from ``u`` to ``n`` with SPFA (queue-based Bellman-Ford).

    Args:
        u: source vertex.
        n: target vertex.
        g: adjacency mapping; ``g[v]`` is an iterable of ``(neighbour, weight)``
            pairs and must be indexable for every vertex that is dequeued.

    Returns:
        The distance from ``u`` to ``n``, or ``math.inf`` when ``n`` was never
        reached.

    Note:
        There is no negative-cycle detection; a negative cycle reachable from
        ``u`` keeps the queue non-empty forever.
    """
    # Imported locally so the function does not rely on names that the file
    # only mentions in commented-out import lines.
    from collections import defaultdict, deque
    from math import inf

    dist = defaultdict(lambda: inf)
    dist[u] = 0
    q = deque([u])
    vis = {u}  # vertices currently waiting in the queue
    while q:
        t = q.popleft()
        vis.remove(t)
        for j, d in g[t]:
            if dist[j] > dist[t] + d:
                dist[j] = dist[t] + d
                # Re-queue a relaxed vertex only if it is not already queued.
                if j not in vis:
                    vis.add(j)
                    q.append(j)
    return dist[n]

dijkstra

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def dij(u, n, g) -> int:
    """Shortest distance from ``u`` to ``n`` with heap-based Dijkstra.

    Args:
        u: source vertex.
        n: target vertex.
        g: adjacency mapping; ``g[v]`` is an iterable of ``(neighbour, weight)``
            pairs and must be indexable for every vertex that is popped.
            Weights are presumably non-negative, as Dijkstra requires;
            confirm against the caller.

    Returns:
        The distance from ``u`` to ``n``, or ``math.inf`` when ``n`` was never
        reached.
    """
    # Imported locally so the function does not rely on names that the file
    # only mentions in commented-out import lines.
    from collections import defaultdict
    from heapq import heappop, heappush
    from math import inf

    dist = defaultdict(lambda: inf)
    dist[u] = 0  # the source is the argument u, not a fixed vertex 1
    heap = [(0, u)]  # (distance, vertex)
    vis = set()      # vertices whose distance is final
    while heap:
        _, node = heappop(heap)
        if node in vis:
            continue  # stale heap entry
        vis.add(node)
        for j, w in g[node]:
            if j not in vis and dist[j] > dist[node] + w:
                dist[j] = dist[node] + w
                heappush(heap, (dist[j], j))
    return dist[n]

kruskal

1
2
3
4
5
6
7
8
9
10
def kruskal():
    """Return the weight of a minimum spanning tree, or ``math.inf`` if none exists.

    Reads the module-level ``n`` (number of vertices) and ``edges`` (a list of
    ``(u, v, w)`` triples, sorted in place by weight) and uses the ``DSU``
    class defined in this file through its ``union`` method and ``setCount``
    attribute.

    NOTE(review): ``DSU(n)`` accepts vertices ``0 .. n - 1``; confirm that the
    edge endpoints are 0-based.
    """
    from math import inf

    dsu = DSU(n)
    edges.sort(key=lambda x: x[2])
    res = 0
    for u, v, w in edges:
        # union() returns False when u and v are already connected, so the
        # edge would close a cycle and is skipped.
        if dsu.union(u, v):
            res += w
    # The graph is spanned only when a single component is left.
    return res if dsu.setCount == 1 else inf