접미사 배열 + LCP를 활용한다.
접미사 배열문자열 S의 모든 접미사 S[i..n-1] (총 n개)를 사전순으로 정렬해 인덱스 배열 SA를 만든다.
LCP 배열인접한 접미사 SA[i-1]와 SA[i]의 최장 공통 접두사 길이 LCP[i]를 계산한다.
서로 다른 부분 문자열 수 계산
$$ 답 = \sum_{i=0}^{n-1}(n - SA[i]) - \sum_{i=1}^{n-1} LCP[i] $$
import sys
def build_sa(s):
n = len(s)
k = 1
# 초기 랭크: 문자 코드
rank = [ord(c) for c in s]
sa = list(range(n))
tmp = [0] * n
while True:
# (rank[i], rank[i+k]) 튜플로 정렬
sa.sort(key=lambda i: (rank[i], rank[i + k] if i + k < n else -1))
tmp[sa[0]] = 0
for i in range(1, n):
prev, curr = sa[i - 1], sa[i]
prev_key = (rank[prev], rank[prev + k] if prev + k < n else -1)
curr_key = (rank[curr], rank[curr + k] if curr + k < n else -1)
tmp[curr] = tmp[prev] + (prev_key < curr_key)
rank, tmp = tmp, rank
if rank[sa[-1]] == n - 1:
break
k <<= 1
return sa
def build_lcp(s, sa):
n = len(s)
rank = [0] * n
for i, si in enumerate(sa):
rank[si] = i
h = 0
lcp = [0] * n
for i in range(n):
if rank[i] == 0:
continue
j = sa[rank[i] - 1]
while i + h < n and j + h < n and s[i + h] == s[j + h]:
h += 1
lcp[rank[i]] = h
if h:
h -= 1
return lcp
def count_distinct_substrings(s):
n = len(s)
sa = build_sa(s)
lcp = build_lcp(s, sa)
total = 0
for i in range(n):
total += n - sa[i] # 이 접미사에서 나올 수 있는 부분 문자열
if i > 0:
total -= lcp[i] # 앞 접미사와 중복되는 부분
return total
s = sys.stdin.readline().strip()
print(count_distinct_substrings(s))