image.png

문제 태그

아이디어

  1. 접미사 배열 + LCP를 활용한다.

  2. 접미사 배열문자열 S의 모든 접미사 S[i..n-1] (총 n개)를 사전순으로 정렬해 인덱스 배열 SA를 만든다.

  3. LCP 배열인접한 접미사 SA[i-1]와 SA[i]의 최장 공통 접두사 길이 LCP[i]를 계산한다.

  4. 서로 다른 부분 문자열 수 계산

    $$ 답 = \sum_{i=0}^{n-1}(n - SA[i]) - \sum_{i=1}^{n-1} LCP[i] $$

정답

import sys
def build_sa(s):
    n = len(s)
    k = 1
    # 초기 랭크: 문자 코드
    rank = [ord(c) for c in s]
    sa = list(range(n))
    tmp = [0] * n

    while True:
        # (rank[i], rank[i+k]) 튜플로 정렬
        sa.sort(key=lambda i: (rank[i], rank[i + k] if i + k < n else -1))
        tmp[sa[0]] = 0
        for i in range(1, n):
            prev, curr = sa[i - 1], sa[i]
            prev_key = (rank[prev], rank[prev + k] if prev + k < n else -1)
            curr_key = (rank[curr], rank[curr + k] if curr + k < n else -1)
            tmp[curr] = tmp[prev] + (prev_key < curr_key)
        rank, tmp = tmp, rank
        if rank[sa[-1]] == n - 1:
            break
        k <<= 1
    return sa

def build_lcp(s, sa):
    n = len(s)
    rank = [0] * n
    for i, si in enumerate(sa):
        rank[si] = i
    h = 0
    lcp = [0] * n
    for i in range(n):
        if rank[i] == 0:
            continue
        j = sa[rank[i] - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[rank[i]] = h
        if h:
            h -= 1
    return lcp

def count_distinct_substrings(s):
    n = len(s)
    sa = build_sa(s)
    lcp = build_lcp(s, sa)
    total = 0
    for i in range(n):
        total += n - sa[i]       # 이 접미사에서 나올 수 있는 부분 문자열
        if i > 0:
            total -= lcp[i]      # 앞 접미사와 중복되는 부분
    return total

s = sys.stdin.readline().strip()
print(count_distinct_substrings(s))