跳转至

2022-06-19-Trie&KMP

课件:TG_C3_03_Trie_KMP-v2022.pdf
| | | | |
| ---- | ------------------------------------ | ----- | ---------------------------------------------------- |
| | HDU4825 Xor Sum(01 Trie) | 提高 | https://acm.hdu.edu.cn/showproblem.php?pid=4825 |
| | Remember the Word, UVa1401 | 提高+ | https://www.luogu.com.cn/problem/UVA1401 |
| | [TJOI2010]阅读理解 | 提高+ | https://www.luogu.com.cn/problem/P3879 |
| | [USACO08DEC G]Secret Message | 提高+ | https://www.luogu.com.cn/problem/P2922 |
| | [BOI2009]Radio Transmission 无线传输 | 普及 | https://www.luogu.com.cn/problem/P4391 |
| | Period, SEERC 2004, POJ1961 | 普及 | http://bailian.openjudge.cn/practice/1961?lang=en_US |

Period

/**
 * @file poj/1961/main
 * @brief 给定一个长度为 n 的字符串 s,求它每个前缀的最短循环节
 * @see
 * @author Ruiming Guo (guoruiming@stu.scu.edu.cn)
 * @copyright 2022
 * @date 2022/6/18 21:32:18
 **/

#include <iostream>
#include <vector>
#define rep(i, a, b) for (int i = (a); i < (int)(b); ++i)
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pi;
const int INF = 0x3f3f3f3f;
const ll LLINF = 0x3f3f3f3f3f3f3f3f;
const int N = 1000000 + 10;
char P[N];
int f[N];
int main() {
  // High rating and good luck!
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  int n, kase = 0;
  while (cin >> n && n) {
    cin >> P;
    f[0] = f[1] = 0;  // 递推初始值
    for (int i = 1; i < n; ++i) {
      int j = f[i];
      while (j && P[i] != P[j]) j = f[j];     // KMP 预处理 next 数组
      f[i + 1] = (P[i] == P[j] ? j + 1 : 0);  // TODO: 这一句什么意思?
    }
    // f[] 中保存了前 i 个字符组成的的循环节的长度
    cout << "Test case #" << ++kase << '\n';
    for (int i = 2; i <= n; ++i)
      if (f[i] > 0 && i % (i - f[i]) == 0)  // 能够表示成前缀的重复
        cout << i << ' ' << i / (i - f[i]) << '\n';
    cout << '\n';
  }
  return 0;
}

Radio Transmission

/**
 * @file luogu/4391/main
 * @brief
 * @see
 * @author Ruiming Guo (guoruiming@stu.scu.edu.cn)
 * @copyright 2022
 * @date 2022/6/18 21:48:41
 **/

#include <iostream>
#include <string>
#include <vector>
#define rep(i, a, b) for (int i = (a); i < (int)(b); ++i)
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pi;
const int INF = 0x3f3f3f3f;
const ll LLINF = 0x3f3f3f3f3f3f3f3f;
const int N = 1e6 + 4;
int f[N], n;
string s;
void getFail() {
  f[0] = f[1] = 0;
  for (int i = 1; i < n; ++i) {
    int j = f[i];
    while (j && s[i] != s[j]) j = f[j];
    f[i + 1] = (s[i] == s[j] ? j + 1 : 0);
  }
}
int main() {
  // High rating and good luck!
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  cin >> n >> s;
  getFail();
  cout << n - f[n] << endl;
  return 0;
}

Xor Sum

/**
 * @file hdu/4825/main
 * @brief 给定 N 个正整数的集合,之后发起 M 次询问,每次询问包含一个
 * 正整数 S,请从集合中找出一个正整数 K,使得 xor(K, S) 最大
 *
 * 构建一个 01-Trie 树,每次选与当前位不同的边(使 xor 结果为 1)
 * @see
 * @author Ruiming Guo (guoruiming@stu.scu.edu.cn)
 * @copyright 2022
 * @date 2022/6/18 21:56:30
 **/

#include <iostream>
#include <vector>
#define rep(i, a, b) for (int i = (a); i < (int)(b); ++i)
using namespace std;
typedef long long ll;
typedef vector<int> vi;
typedef pair<int, int> pi;
const int INF = 0x3f3f3f3f;
const ll LLINF = 0x3f3f3f3f3f3f3f3f;
const int N = 1e5 + 10, WL = 32;
struct Trie01 {
  int sz, ch[WL * N][2];
  ll val[WL * N];
  int newNode() {
    ch[sz][0] = ch[sz][1] = 0, val[sz] = 0;
    return sz++;
  }
  void init() {
    sz = 0;
    newNode();
  }
  void insert(ll x) {
    int u = 0;
    for (int i = WL; i >= 0; --i) {
      int v = (x >> i) & 1, &c = ch[u][v];
      if (!c) c = newNode();
      u = c;
    }
    val[u] = x;
  }
  ll query(ll x) {
    int u = 0;
    for (int i = WL; i >= 0; --i) {
      int v = (x >> i) & 1;
      u = (ch[u][v ^ 1] ? ch[u][v ^ 1] : ch[u][v]);
    }
    return val[u];
  }
};
Trie01 trie;
int main() {
  // High rating and good luck!
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
  int t;
  cin >> t;
  for (int n, m, kase = 1; cin >> n >> m && kase <= t; ++kase) {
    ll x;
    trie.init();
    while (n--) cin >> x, trie.insert(x);
    printf("Case #%d:\n", kase);
    while (m--) cin >> x, printf("%lld\n", trie.query(x));
  }
  return 0;
}

阅读理解

/**
 * @file luogu/3879/main
 * @brief
 * @see
 * @author Ruiming Guo (guoruiming@stu.scu.edu.cn)
 * @copyright 2022
 * @date 2022/6/21 20:45:50
 **/

#include <bitset>
#include <iostream>
using namespace std;
int n, m;
char s[1010];
int l;
int tot = 0;
int trie[300007][26];
vector<short> v[300007];
void insert(char *s, int x) {
  int rt = 0;
  for (int i = 0; s[i]; ++i) {
    int v = s[i] - 'a';
    if (!trie[rt][v]) trie[rt][v] = ++tot;
    rt = trie[rt][v];
  }
  b[rt][x] = 1;
}
void query(char *s) {
  int rt = 0;
  for (int i = 0; s[i]; ++i) {
    int v = s[i] - 'a';
    if (!trie[rt][v]) {
      cout << ' ' << endl;
      return;
    }
    rt = trie[rt][v];
  }
  for (int i = 1; i <= n; ++i) {
    if (b[rt][i] == 1) cout << i << ' ';
  }
  cout << endl;
}
int main() {
  cin >> n;
  for (int i = 1; i <= n; ++i) {
    cin >> l;
    for (int j = 1; j <= l; ++j) {
      cin >> s;
      insert(s, i);
    }
  }
  cin >> m;
  for (int i = 1; i <= m; ++i) {
    cin >> s;
    query(s);
  }
  return 0;
}

Remembering the words

思路:Trie + DP

给定一个字符串 \(s\),长度为 \(n\)

状态:\(f[i]\) 表示 \(s[i:n]\) 的的串

转移:\(f[i] = \sum f[i + x.length]\),其中字符串 \(x\)\(s[i:n]\) 的前缀。很显然如果 \(x\) 不是单词则 \(f[x]=0\)

因为最多可能有 \(4000\) 个单词,时间复杂度无法承受;

可以将单词全部插入 Trie 中,在其中查找 \(s[i:n]\),每经过一个单词节点就能找到转移方程中的一个 \(x\),最多只需 \(100\) 次查找就可以找到所有 \(x\)

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 4000 * 100 + 10, SIGMA = 26;
// 字母表为全体小写字母的Trie
struct Trie {
  int ch[MAXN][SIGMA], val[MAXN], sz;                // 结点总数
  void clear() { sz = 1, fill_n(ch[0], SIGMA, 0); }  // 初始时只有一个根结点
  int idx(char c) { return c - 'a'; }                // 字符c的编号
  // 插入字符串s,附加信息为v。注意v必须非0,因为0代表“本结点不是单词结点”
  void insert(const char *s, int v) {
    int u = 0, n = strlen(s);
    for (int i = 0; i < n; i++) {
      int c = idx(s[i]);
      if (!ch[u][c])  // 结点不存在, 新建
        fill_n(ch[sz], SIGMA, 0), val[sz] = 0, ch[u][c] = sz++;
      u = ch[u][c];  // 往下走
    }
    val[u] = v;  // 字符串的最后一个字符的附加信息为v
  }
  // 查找s的长度不超过len的前缀
  void find_prefixes(const char *s, int len, vector<int> &ans) {
    int u = 0;
    for (int i = 0; i < len; i++) {
      if (s[i] == '\0') break;
      int c = idx(s[i]);
      if (!ch[u][c]) break;
      u = ch[u][c];
      if (val[u] != 0) ans.push_back(val[u]);  // 找到一个前缀
    }
  }
};

// 文本串最大长度, 单词最大个数
const int TL = 3e5 + 4, WC = 4000 + 4, MOD = 20071027;
int D[TL], WLen[WC];
Trie trie;
int main() {
  ios::sync_with_stdio(false), cin.tie(0);
  string text, word;
  for (int kase = 1, S; cin >> text >> S; kase++) {
    trie.clear();
    for (int i = 1; i <= S; i++)
      cin >> word, WLen[i] = word.length(), trie.insert(word.c_str(), i);
    int L = text.length();
    fill_n(D, L, 0), D[L] = 1;
    for (int i = L - 1; i >= 0; i--) {
      vector<int> p;
      trie.find_prefixes(text.c_str() + i, L - i, p);
      for (size_t j = 0; j < p.size(); j++) (D[i] += D[i + WLen[p[j]]]) %= MOD;
    }
    printf("Case %d: %d\n", kase, D[0]);
  }
  return 0;
}

[USACO08DEC]Secret Message G

截获了 \(M\) 条二进制信息的前缀,知道第 \(i\) 条信息的前缀为 \(b_i\);同时知道有 \(N\) 条暗号,知道第 \(j\) 条暗号的前缀为 \(c_j\)。问对于每条暗号,有多少信息可能是这条暗号。

也就是说,

  • 情形一:若 \(|b_i| < |c_j|\)\(b_i\)\(c_j\) 的一个前缀
  • 情形二:若 \(|b_i| = |c_j|\)\(b_i = c_j\)
  • 情形三:若 \(|b_i| > |c_j|\)\(c_j\)\(b_i\) 的一个前缀

  • \(M\) 条信息插入 Trie,在路过的点上标记一个 \(sum\),在结尾的点上标记一个 \(end\)

  • 对于每个暗号,沿着根往下走,沿途把每个点的 \(end\) 加入 \(ans\)
  • 如果走到哪个点失配,每个点都是情形一,没有情形二三,输出记录 \(ans\) 即可
  • 如果完成匹配,还要考虑情形二三。对于情形二,在 \(ans\) 里加上这个点的 \(end\) 时已经考虑了;对于情形三,需要统计会继续走下去的单词的个数,有 \(sum - end\) 个(只是路过,并不停留)。
/**
 * @file luogu/2922/main.cpp
 * @brief
 * @see
 * @author Ruiming Guo (guoruiming@stu.scu.edu.cn)
 * @copyright 2022
 * @date 2022/6/22 09:26:13
 **/

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
struct msg {
  int ch[2], sum, end;
} T[500001];
int A[10001];
int main() {
  // High rating and good luck!
  int m, n, sz = 0;
  cin >> m >> n;
  for (int i = 1, len; i <= m; ++i) {
    cin >> len;
    int u = 0;
    for (int j = 1, b; j <= len; ++j) {
      cin >> b;
      int &v = T[u].ch[b];
      if (v == 0) v = ++sz;
      u = v, T[u].sum++;  // 经过此节点的信号增加了
    }
    T[u].end++;  // 在此节点结束的信号增加
  }
  for (int i = 1, len; i <= n; ++i) {
    cin >> len;
    for (int j = 1; j <= len; ++j) cin >> A[j];
    int u = 0, ans = 0;
    bool match = true;
    for (int j = 1; j <= len; ++j) {
      if (T[u].ch[A[j]] != 0)
        u = T[u].ch[A[j]];
      else {
        match = false;
        break;
      }
      ans += T[u].end;  // 情形一
    }
    if (match) ans += T[u].sum - T[u].end;  // 情形三
    cout << ans << endl;
  }
  return 0;
}