import sys
from collections import *

# Optional imports for competitive programming; uncomment as a problem needs them.
# write=lambda x:sys.stdout.write(str(x)+'\n')
# from decimal import Decimal
# from datetime import datetime,timedelta
# from random import randint
# from copy import deepcopy
# from heapq import heapify,heappush,heappop
# from bisect import bisect_left,bisect,insort
# NOTE: importing math.pow shadows the builtin three-argument pow(base, exp, mod).
# from math import inf,sqrt,gcd,pow,ceil,floor,log,log2,log10,pi,sin,cos,tan,asin,acos,atan
# from functools import cmp_to_key,reduce
# from operator import or_,xor,add,mul
# from itertools import permutations,combinations,accumulate

# Allow deep recursion (recursive DFS / extended Euclid, etc.).
sys.setrecursionlimit(100000)


def input():
    """Read one line from stdin with surrounding whitespace stripped.

    Deliberately shadows the builtin ``input`` with a ``sys.stdin.readline``
    based reader.
    """
    return sys.stdin.readline().strip()


def sint():
    """Read one line and parse it as a single int."""
    return int(input())


def mint():
    """Read one line of whitespace-separated ints as a lazy ``map``."""
    return map(int, input().split())


def lint():
    """Read one line of whitespace-separated ints as a list."""
    return list(map(int, input().split()))
def solve():
    """Solve one test case (template stub: fill in per problem)."""
if __name__ == '__main__':
    # For problems with multiple test cases, replace the call below with:
    # t = int(input())
    # for _ in range(t):
    #     solve()
    solve()
def union(self, x, y) -> bool:
    """Merge the sets containing ``x`` and ``y`` (union by size).

    The smaller set's root is attached under the larger set's root, and
    ``self.size`` / ``self.setCount`` are updated.

    Returns:
        True if two different sets were merged, False if ``x`` and ``y``
        were already in the same set.
    """
    x, y = self.find(x), self.find(y)
    if x == y:
        return False
    # Keep x as the root of the larger set so trees stay shallow.
    if self.size[x] < self.size[y]:
        x, y = y, x
    self.parent[y] = x
    self.size[x] += self.size[y]
    self.setCount -= 1
    return True
def connected(self, x, y) -> bool:
    """Return True if ``x`` and ``y`` belong to the same set."""
    return self.find(x) == self.find(y)
def pre(n):
    """Linear (Euler) sieve over 2..n.

    Appends every prime <= n to the module-level list ``pri`` and sets
    ``not_prime[k] = True`` for every composite k <= n. ``not_prime`` must
    have length > n. Each composite is crossed out once, by its smallest
    prime factor, because the inner loop stops at the first prime dividing i.

    Returns:
        The ``pri`` list (the same object that was filled in).
    """
    for i in range(2, n + 1):
        if not not_prime[i]:
            pri.append(i)
        for p in pri:
            if i * p > n:
                break
            not_prime[i * p] = True
            # p is the smallest prime factor of i: larger primes q would give
            # i*q whose smallest factor is p, which a later i will handle.
            if i % p == 0:
                break
    return pri
# 质因数分解、约数个数与约数之和 (prime factorization; number and sum of divisors)
def divide(n):
    """Return the prime factorization of ``n`` by trial division.

    Returns:
        A list of ``(prime, exponent)`` pairs in increasing prime order;
        ``[]`` for n < 2.
    """
    factors = []
    i = 2
    while i <= n // i:  # i*i <= n without the multiplication
        if n % i == 0:
            cnt = 0
            while n % i == 0:
                cnt += 1
                n //= i
            factors.append((i, cnt))
        i += 1
    # Whatever remains is a single prime factor larger than sqrt(original n).
    if n > 1:
        factors.append((n, 1))
    return factors
# Sum of divisors of the product of all input numbers, modulo 1e9 + 7.
# Input: a count, then that many integers, one per line.
mod = 10**9 + 7
primes = {}  # prime -> total exponent across all input numbers
for _ in range(int(input())):
    n = int(input())
    i = 2
    while i <= n // i:
        while n % i == 0:
            n //= i
            primes[i] = primes.get(i, 0) + 1
        i += 1
    if n > 1:  # leftover factor is a prime > sqrt(original n)
        primes[n] = primes.get(n, 0) + 1
res = 1
for p, e in primes.items():
    # t = 1 + p + p^2 + ... + p^e (mod mod), built Horner-style.
    t = 1
    for _ in range(e):
        t = (t * p + 1) % mod
    res = res * t % mod
print(res)
# Number of divisors of the product of all input numbers, modulo 1e9 + 7.
# Input: a count, then that many integers, one per line.
mod = 10**9 + 7  # defined here too so this program does not depend on an earlier one
primes = {}  # prime -> total exponent across all input numbers
for _ in range(int(input())):
    n = int(input())
    i = 2
    while i <= n // i:
        while n % i == 0:
            n //= i
            primes[i] = primes.get(i, 0) + 1
        i += 1
    if n > 1:  # leftover factor is a prime > sqrt(original n)
        primes[n] = primes.get(n, 0) + 1
res = 1
for e in primes.values():
    # Each prime contributes (exponent + 1) choices.
    res = res * (e + 1) % mod
print(res)
# 欧拉函数 (Euler's totient function φ)
# Euler's totient φ(a) for each query, by trial division:
# φ(a) = a * Π (1 - 1/p) over the distinct primes p dividing a.
n = int(input())
for _ in range(n):
    a = int(input())
    res = a
    j = 2
    while j * j <= a:
        if a % j == 0:
            # Exact: res is still divisible by j at this point.
            res = res * (j - 1) // j
            while a % j == 0:
                a //= j
        j += 1
    if a > 1:  # one remaining prime factor
        res = res * (a - 1) // a
    print(res)
# Shared tables used by get_eulers(); sized for n up to about 10**6.
N = 1000010
primes = [0] * N   # primes found so far, packed at the front
phi = [0] * N      # phi[k]: Euler's totient of k once get_eulers(n) ran with n >= k
st = [False] * N   # st[k] becomes True once k is crossed out as composite
def get_eulers(n):
    """Fill ``phi[1..n]`` with Euler's totient using a linear sieve.

    Works on the module-level arrays ``primes``, ``phi`` and ``st`` (each must
    have length > n). Afterwards the primes <= n sit at the front of
    ``primes`` and ``st[k]`` is True for every composite k <= n.
    """
    phi[1] = 1
    cnt = 0
    for i in range(2, n + 1):
        if not st[i]:
            primes[cnt] = i
            cnt += 1
            phi[i] = i - 1
        j = 0
        # Terminates via break: i's smallest prime factor is already in primes.
        while primes[j] <= n // i:
            p = primes[j]
            st[p * i] = True
            if i % p == 0:
                # p already divides i, so phi(p*i) = phi(i) * p.
                phi[p * i] = phi[i] * p
                break
            # p and i are coprime, so phi is multiplicative here.
            phi[p * i] = phi[i] * (p - 1)
            j += 1
# 扩展欧几里得 (extended Euclidean algorithm)
def extend_gcd(a, b, x=0, y=0):
    """Extended Euclid: return ``(d, x, y)`` with ``d = gcd(a, b)`` and ``a*x + b*y == d``.

    The ``x`` and ``y`` arguments are not read; they are accepted (now with
    defaults) only so existing four-argument calls keep working.
    """
    if not b:
        return a, 1, 0
    # b*x1 + (a % b)*y1 == d  and  a % b == a - (a // b) * b.
    d, x1, y1 = extend_gcd(b, a % b)
    return d, y1, x1 - a // b * y1
# For each query pair (a, b), print one solution (x, y) of a*x + b*y = gcd(a, b).
n = int(input())
for _ in range(n):
    a, b = map(int, input().split())
    d, x, y = extend_gcd(a, b, 0, 0)
    print(x, y)
# 线性同余方程 (linear congruence: solve a·x ≡ b (mod m))
def extend_gcd(a, b, x=0, y=0):
    """Extended Euclid: return ``(d, x, y)`` with ``d = gcd(a, b)`` and ``a*x + b*y == d``.

    The ``x`` and ``y`` arguments are not read; they are accepted (now with
    defaults) only so existing four-argument calls keep working.
    """
    if not b:
        return a, 1, 0
    # b*x1 + (a % b)*y1 == d  and  a % b == a - (a // b) * b.
    d, x1, y1 = extend_gcd(b, a % b)
    return d, y1, x1 - a // b * y1
# For each query (a, b, m), print some x with a*x ≡ b (mod m), or "impossible".
n = int(input())
for _ in range(n):
    a, b, m = map(int, input().split())
    d, x, y = extend_gcd(a, m, 0, 0)
    # a*x + m*y = d; a solution exists iff d divides b, then scale by b // d.
    if b % d:
        print("impossible")
    else:
        print(x * (b // d) % m)
# 求组合数 (binomial coefficient C(n, m))
import math


def C(n, m):
    """Return the binomial coefficient C(n, m) ("n choose m") exactly.

    Returns 0 when m < 0 or m > n. Previously ``math.factorial`` raised
    ValueError on the negative ``n - m``.
    """
    if m < 0 or m > n:
        return 0
    return math.factorial(n) // (math.factorial(m) * math.factorial(n - m))


n, m = map(int, input().split())
print(C(n, m))
# Query the maximum over positions [L, R].
# Usage: query_max(L, R, 1, N, 1)
def query_max(self, L: int, R: int, left: int, right: int, root: int):
    """Return the maximum of ``self.max`` leaves over positions [L, R] (inclusive).

    ``root`` is the node covering [left, right]; its children are
    ``2 * root`` (covering [left, mid]) and ``2 * root + 1`` (covering
    [mid + 1, right]).

    Child results are combined directly instead of being seeded with 0, so
    ranges whose values are all negative are answered correctly. 0 is
    returned only when neither half overlaps [L, R] (an empty query), the
    same value the previous implementation returned there.
    """
    if L <= left and right <= R:
        return self.max[root]
    mid = (left + right) // 2
    best = None
    if L <= mid:
        best = self.query_max(L, R, left, mid, 2 * root)
    if R > mid:
        right_best = self.query_max(L, R, mid + 1, right, 2 * root + 1)
        best = right_best if best is None else max(best, right_best)
    return 0 if best is None else best
def solve():
    """Process ``m`` operations on a growing sequence, modulo ``p``.

    ``A l`` appends ``(last + l) % p``. ``Q l`` prints the maximum of the last
    ``l`` appended values and remembers it as ``last``. Assumes
    ``segtree.add(pos, val, 1, m, 1)`` stores ``val`` at position ``pos``
    (TODO confirm against the segtree class).
    """
    m, p = mint()
    se = segtree(m, 0)
    last = 0
    n = 0  # number of values appended so far
    for _ in range(m):
        op, l = input().split()
        l = int(l)
        if op == 'A':
            se.add(n + 1, (last + l) % p, 1, m, 1)
            n += 1
        else:
            last = se.query_max(n - l + 1, n, 1, m, 1)
            print(last)


if __name__ == '__main__':
    solve()
class SortedList:
    """A sorted multiset backed by a list of buckets plus a Fenwick tree.

    Values are kept in ascending order across ``self._lists`` (a list of
    sorted buckets). ``self._mins[i]`` caches bucket i's first value and
    ``self._list_lens[i]`` its length. ``add`` splits a bucket once it holds
    more than ``2 * _load`` items. A Fenwick tree over the bucket lengths,
    rebuilt lazily whenever buckets are created or removed, maps a global
    index to a ``(bucket, offset)`` pair.
    """

    def __init__(self, iterable=(), _load=200):
        """Initialize sorted list instance from ``iterable`` (copied, then sorted)."""
        values = sorted(iterable)
        self._len = _len = len(values)
        self._load = _load
        self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
        self._list_lens = [len(_list) for _list in _lists]
        self._mins = [_list[0] for _list in _lists]
        self._fen_tree = []
        self._rebuild = True

    def _fen_build(self):
        """Build a fenwick tree instance."""
        self._fen_tree[:] = self._list_lens
        _fen_tree = self._fen_tree
        for i in range(len(_fen_tree)):
            if i | i + 1 < len(_fen_tree):
                _fen_tree[i | i + 1] += _fen_tree[i]
        self._rebuild = False

    def _fen_update(self, index, value):
        """Update `fen_tree[index] += value`."""
        # While a rebuild is pending the tree will be rebuilt from scratch,
        # so incremental updates are skipped.
        if not self._rebuild:
            _fen_tree = self._fen_tree
            while index < len(_fen_tree):
                _fen_tree[index] += value
                index |= index + 1

    def _fen_query(self, end):
        """Return `sum(_fen_tree[:end])`."""
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        x = 0
        while end:
            x += _fen_tree[end - 1]
            end &= end - 1
        return x

    def _fen_findkth(self, k):
        """Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
        _list_lens = self._list_lens
        # Fast paths for the first and last bucket avoid a Fenwick walk.
        if k < _list_lens[0]:
            return 0, k
        if k >= self._len - _list_lens[-1]:
            return len(_list_lens) - 1, k + _list_lens[-1] - self._len
        if self._rebuild:
            self._fen_build()

        _fen_tree = self._fen_tree
        idx = -1
        for d in reversed(range(len(_fen_tree).bit_length())):
            right_idx = idx + (1 << d)
            if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
                idx = right_idx
                k -= _fen_tree[idx]
        return idx + 1, k

    def _delete(self, pos, idx):
        """Delete value at the given `(pos, idx)`."""
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len -= 1
        self._fen_update(pos, -1)
        del _lists[pos][idx]
        _list_lens[pos] -= 1

        if _list_lens[pos]:
            _mins[pos] = _lists[pos][0]
        else:
            # Bucket became empty: drop it; bucket indices shift, so rebuild.
            del _lists[pos]
            del _list_lens[pos]
            del _mins[pos]
            self._rebuild = True

    def _loc_left(self, value):
        """Return an index pair that corresponds to the first position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        lo, pos = -1, len(_lists) - 1
        while lo + 1 < pos:
            mi = (lo + pos) >> 1
            if value <= _mins[mi]:
                pos = mi
            else:
                lo = mi

        # Equal values may also end the previous bucket.
        if pos and value <= _lists[pos - 1][-1]:
            pos -= 1

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value <= _list[mi]:
                idx = mi
            else:
                lo = mi

        return pos, idx

    def _loc_right(self, value):
        """Return an index pair that corresponds to the last position of `value` in the sorted list."""
        if not self._len:
            return 0, 0

        _lists = self._lists
        _mins = self._mins

        pos, hi = 0, len(_lists)
        while pos + 1 < hi:
            mi = (pos + hi) >> 1
            if value < _mins[mi]:
                hi = mi
            else:
                pos = mi

        _list = _lists[pos]
        lo, idx = -1, len(_list)
        while lo + 1 < idx:
            mi = (lo + idx) >> 1
            if value < _list[mi]:
                idx = mi
            else:
                lo = mi

        # This return was missing, which made discard/remove/bisect_right/
        # count fail on a None result.
        return pos, idx

    def _normalize_index(self, index):
        """Map a possibly negative ``index`` into ``[0, len)``; raise IndexError otherwise."""
        if index < 0:
            index += self._len
        if not 0 <= index < self._len:
            raise IndexError('SortedList index out of range')
        return index

    def add(self, value):
        """Add `value` to sorted list."""
        _load = self._load
        _lists = self._lists
        _mins = self._mins
        _list_lens = self._list_lens

        self._len += 1
        if _lists:
            pos, idx = self._loc_right(value)
            self._fen_update(pos, 1)
            _list = _lists[pos]
            _list.insert(idx, value)
            _list_lens[pos] += 1
            _mins[pos] = _list[0]
            # Split an oversized bucket in two; bucket indices shift, so rebuild.
            if _load + _load < len(_list):
                _lists.insert(pos + 1, _list[_load:])
                _list_lens.insert(pos + 1, len(_list) - _load)
                _mins.insert(pos + 1, _list[_load])
                _list_lens[pos] = _load
                del _list[_load:]
                self._rebuild = True
        else:
            _lists.append([value])
            _mins.append(value)
            _list_lens.append(1)
            self._rebuild = True

    def discard(self, value):
        """Remove `value` from sorted list if it is a member."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_right(value)
            if idx and _lists[pos][idx - 1] == value:
                self._delete(pos, idx - 1)

    def remove(self, value):
        """Remove `value` from sorted list; `value` must be a member."""
        _len = self._len
        self.discard(value)
        if _len == self._len:
            raise ValueError('{0!r} not in list'.format(value))

    def pop(self, index=-1):
        """Remove and return value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._normalize_index(index))
        value = self._lists[pos][idx]
        self._delete(pos, idx)
        return value

    def bisect_left(self, value):
        """Return the first index to insert `value` in the sorted list."""
        pos, idx = self._loc_left(value)
        return self._fen_query(pos) + idx

    def bisect_right(self, value):
        """Return the last index to insert `value` in the sorted list."""
        pos, idx = self._loc_right(value)
        return self._fen_query(pos) + idx

    def count(self, value):
        """Return number of occurrences of `value` in the sorted list."""
        return self.bisect_right(value) - self.bisect_left(value)

    def __len__(self):
        """Return the size of the sorted list."""
        return self._len

    def __getitem__(self, index):
        """Lookup value at `index` in sorted list."""
        pos, idx = self._fen_findkth(self._normalize_index(index))
        return self._lists[pos][idx]

    def __delitem__(self, index):
        """Remove value at `index` from sorted list."""
        pos, idx = self._fen_findkth(self._normalize_index(index))
        self._delete(pos, idx)

    def __contains__(self, value):
        """Return true if `value` is an element of the sorted list."""
        _lists = self._lists
        if _lists:
            pos, idx = self._loc_left(value)
            return idx < len(_lists[pos]) and _lists[pos][idx] == value
        return False

    def __iter__(self):
        """Return an iterator over the sorted list."""
        return (value for _list in self._lists for value in _list)

    def __reversed__(self):
        """Return a reverse iterator over the sorted list."""
        return (value for _list in reversed(self._lists) for value in reversed(_list))

    def __repr__(self):
        """Return string representation of sorted list."""
        return 'SortedList({0})'.format(list(self))
def get_hash(self, l, r):
    """Return the polynomial hash of the half-open substring [l, r), modulo ``self.mod``.

    Computed as ``h[r] - h[l] * p[r - l]``. Assumes ``self.h[i]`` is the
    prefix hash of the first i characters and ``self.p[k]`` is base**k modulo
    ``self.mod`` (TODO confirm against the class constructor). Python's ``%``
    keeps the result non-negative.
    """
    res = self.h[r] - self.h[l] * self.p[r - l]
    return res % self.mod
# 二维差分 (2-D difference array)
# 2-D difference array: apply q rectangle increments to matrix a, then print it.
# Assumes n, m, q, a, b and mint() are defined earlier, with a and b indexed
# 0..n+1 x 0..m+1 and row/column 0 of a zero (TODO confirm against setup code).
# Build b as the 2-D difference of a (inverse of the 2-D prefix sum).
for i in range(1, n + 1):
    for j in range(1, m + 1):
        b[i][j] = a[i][j] - a[i - 1][j] - a[i][j - 1] + a[i - 1][j - 1]
# Each update adds c to every cell of the rectangle (x1, y1)-(x2, y2).
for _ in range(q):
    x1, y1, x2, y2, c = mint()
    b[x1][y1] += c
    b[x1][y2 + 1] -= c
    b[x2 + 1][y1] -= c
    b[x2 + 1][y2 + 1] += c
# Rebuild a as the 2-D prefix sum of b and print it row by row.
for i in range(1, n + 1):
    for j in range(1, m + 1):
        a[i][j] = b[i][j] + a[i - 1][j] + a[i][j - 1] - a[i - 1][j - 1]
        print(a[i][j], end=' ')
    print()
# 图论 (graph theory)
# SPFA (single-source shortest paths; allows negative edge weights)
def spfa(u, n, g) -> int:
    """Return the shortest distance from ``u`` to ``n`` using SPFA (queue-based Bellman-Ford).

    ``g[v]`` must be iterable of ``(neighbor, weight)`` pairs for every vertex
    reached. Negative weights are allowed, but the loop never terminates if a
    negative cycle is reachable from ``u``. Returns ``math.inf`` when ``n`` is
    unreachable.
    """
    # Local imports: the file header leaves math.inf commented out.
    from collections import defaultdict, deque
    from math import inf

    dist = defaultdict(lambda: inf)
    dist[u] = 0
    queue = deque([u])
    in_queue = {u}  # vertices currently waiting in the queue
    while queue:
        t = queue.popleft()
        in_queue.remove(t)
        for j, d in g[t]:
            if dist[j] > dist[t] + d:
                dist[j] = dist[t] + d
                if j not in in_queue:
                    in_queue.add(j)
                    queue.append(j)
    return dist[n]
# Dijkstra (single-source shortest paths; non-negative edge weights)
def dij(u, n, g) -> int:
    """Return the shortest distance from ``u`` to ``n`` using heap-based Dijkstra.

    ``g[v]`` must be iterable of ``(neighbor, weight)`` pairs for every vertex
    reached. The algorithm requires non-negative weights; this is not checked.
    Stale heap entries are skipped via the ``done`` set. Returns ``math.inf``
    when ``n`` is unreachable.

    Previously the source distance was hard-coded as ``dist[1] = 0``, so any
    ``u != 1`` produced inf everywhere.
    """
    # Local imports: the file header leaves heapq and math.inf commented out.
    from collections import defaultdict
    from heapq import heappop, heappush
    from math import inf

    dist = defaultdict(lambda: inf)
    dist[u] = 0
    heap = [(0, u)]  # (distance, vertex)
    done = set()
    while heap:
        _, v = heappop(heap)
        if v in done:
            continue
        done.add(v)
        for w, weight in g[v]:
            if w not in done and dist[w] > dist[v] + weight:
                dist[w] = dist[v] + weight
                heappush(heap, (dist[w], w))
    return dist[n]
# Kruskal (minimum spanning tree)
def kruskal():
    """Return the total weight of a minimum spanning tree, or ``math.inf`` if none exists.

    Reads the module-level ``n`` (passed to ``DSU``), ``edges`` (a list of
    ``(u, v, w)`` tuples, sorted in place by weight) and the ``DSU`` class.
    Assumes ``DSU`` provides ``same(u, v)``, ``merge(u, v)`` and an attribute
    ``n`` holding the number of remaining components (TODO confirm against
    the DSU class).
    """
    # Local import: the file header leaves math.inf commented out.
    from math import inf

    dsu = DSU(n)
    edges.sort(key=lambda e: e[2])
    total = 0
    for u, v, w in edges:
        if dsu.same(u, v):
            continue  # edge would close a cycle
        dsu.merge(u, v)
        total += w
    return total if dsu.n == 1 else inf