DQS DQS

2017ACM-ICPC亚洲区(西安赛区)网络赛-A.Tree LCA|树剖|矩阵

in 矩乘,bitset,算法,图论,LCA,数论,STL,树链剖分read (172) 文章转载请注明来源!

题目链接


题意:给你n个点的树,每个点上有一个01矩阵(给定种子生成)。每次询问u到v路径上的矩阵顺次乘起来得到的矩阵(每个点模2,也就是说得到的也是01矩阵),通过公式计算输出对应数字(简化输出)。$n<=3000,Q<=30000$。时限9s。

简单地说就是给一棵树,每个点有一个矩阵,每次询问链上矩阵乘积。

算法一

很显然得到一个算法:剖一下,线段树维护从左到右乘起来的矩阵和从右到左乘起来的矩阵,查询的时候按顺序乘起来即可。
复杂度$O(Qlog^2(n)64^3)$,过不了。

可以用bitset或者unsigned long long优化矩阵乘法,复杂度降为$O(Qlog^2(n)64^2)$,还是过不了。并且实测ULL更优。

算法二

可以发现没有修改操作,所以可以不用线段树。对于一个点u,维护u到top[u]和top[u]到u的矩阵。最后u和v在同一条重链上时倍增上去。

算法三

这时候我突然发现,都用上倍增了为什么还要剖!于是这个题就只需要LCA即可。这时复杂度$O(log(n)64^2Q)$,可以过了。

矩乘的优化

如果用bitset,那么矩乘为c.r[i][j] = (a.r[i] & b.c[j]).count() & 1。如果用ULL,矩乘要用一个叫popcount的东西。这两个的常数都很大。
思考c[i][j]^=a[i][k] & b[k][j]。若固定i和k不动,则可理解为c[i][j]^=b[k][j] [a[i][k]==1]。就是当a[i][k]为1时,c的第i行异或上b的第k行。这样就不需要存贮列向量,还能少一个大常数,程序跑的就很快了。

代码贴一个稍短点的……之前zz写的剖就不贴了。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<map>
#include<queue>
#include<cmath>
#include<bitset>
using namespace std;

typedef unsigned long long ULL;
typedef long long LL;
const int SZ = 3010;
const int INF = 1000000010;
const LL mod = 19260817;

LL read()
{
    LL n = 0;
    char a = getchar();
    int flag = 0;
    while(a < '0' || a > '9') { if(a == '-') flag = 1; a = getchar(); }
    while(a >= '0' && a <= '9') n = n * 10 + a - '0',a = getchar();
    if(flag) n = -n;
    return n;
}

LL m19[70],m26[70];

struct matrix
{
    ULL r[65];
    matrix()
    {
        for(int i = 0;i < 64;i ++)
            r[i] = (1llu << i);
    }
    void print()
    {
        for(int i = 0;i < 64;i ++,puts(""))
            for(int j = 0;j < 64;j ++)
                cout << ((r[i] >> j) & 1);
    }
}g[SZ],up[SZ],down[SZ];

matrix operator *(const matrix &a,const matrix &b)
{
    matrix ans;
    for(int i = 0;i < 64;i ++) ans.r[i] = 0;
    for(int i = 0;i < 64;i ++)
        for(int j = 0;j < 64;j ++)
            if(a.r[i] >> j & 1)
                ans.r[i] ^= b.r[j];
    return ans;
}

LL get(const matrix &a)
{
    LL ans = 0;
    for(int i = 0;i < 64;i ++)
        for(int j = 0;j < 64;j ++)
            ans = (ans + ((a.r[i] >> j) & 1) * m19[i + 1] * m26[j + 1] % mod) % mod;
    return ans;
}

int n,m,head[SZ],nxt[SZ * 2],tot = 1;

struct edge
{
    int t;
}l[SZ * 2];

void build(int f,int t)
{
    l[++ tot] = (edge){t};
    nxt[tot] = head[f];
    head[f] = tot;
}

matrix U[SZ][15],D[SZ][15];
int anc[SZ][15],deep[SZ];
void dfs_lca(int u,int fa)
{
    deep[u] = deep[fa] + 1;
    anc[u][0] = fa;
    D[u][0] = U[u][0] = g[u];
    for(int i = 1;anc[u][i - 1];i ++)
    {
        anc[u][i] = anc[anc[u][i - 1]][i - 1];
        U[u][i] = U[u][i - 1] * U[anc[u][i - 1]][i - 1];
        D[u][i] = D[anc[u][i - 1]][i - 1] * D[u][i - 1];
    }
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(v == fa) continue;
        dfs_lca(v,u);
    }
}

int ask_lca(int u,int v)
{
    if(deep[u] < deep[v]) swap(u,v);
    if(deep[u] > deep[v])
    {
        int dd = deep[u] - deep[v];
        for(int i = 0;i <= 12;i ++)
            if(dd & (1 << i))
                u = anc[u][i];
    }
    if(u == v) return u;
    for(int i = 12;i >= 0;i --)
        if(anc[u][i] != anc[v][i])
            u = anc[u][i],v = anc[v][i];
    return anc[u][0];
}

matrix S[100];
matrix get_ans(int u,int v)
{
    matrix ans;
    if(deep[u] > deep[v])
    {
        int dd = deep[u] - deep[v];
        for(int i = 0;i <= 12;i ++)
            if(dd & (1 << i))
                ans = ans * U[u][i],u = anc[u][i];
    }
    else
    {
        int top = 0,dd = deep[v] - deep[u];
        for(int i = 0;i <= 12;i ++)
            if(dd & (1 << i))
                S[++ top] = D[v][i],v = anc[v][i];
        for(int i = top;i >= 1;i --) ans = ans * S[i];
    }
    return ans;
}

matrix ask_ans(int x,int y)
{
    int lca = ask_lca(x,y);
    return get_ans(x,lca) * g[lca] * get_ans(lca,y);
}


int main()
{
    m19[0] = m26[0] = 1;
    for(int i = 1;i <= 64;i ++)
        m19[i] = m19[i - 1] * 19 % mod,
        m26[i] = m26[i - 1] * 26 % mod;
    while(~scanf("%d%d",&n,&m))
    {
        tot = 1;
        memset(head,0,sizeof(head));
        memset(anc,0,sizeof(anc));
        for(int i = 1;i <= n - 1;i ++)
        {
            int u = read(),v = read();
            build(u,v); build(v,u);
        }
        ULL seed;
        scanf("%llu",&seed);
        for(int i = 1;i <= n;i ++)
            for(int j = 0;j < 64;j ++)
                g[i].r[j] = 0;

        for(int i = 1;i <= n;i ++)
            for(int p = 1;p <= 64;p ++)
            {
                seed ^= seed * seed + 15;
                for(int q = 1;q <= 64;q ++)
                {
                    ULL x = (seed >> (q - 1)) & 1;
                    g[i].r[p - 1] |= x << (q - 1);
                }
            }
        dfs_lca(1,0);
        while(m --)
        {
            int u = read(),v = read();
            printf("%lld\n",get(ask_ans(u,v)));
        }
    }
    return 0;
}
jrotty WeChat Pay

微信打赏

jrotty Alipay

支付宝打赏

文章二维码

扫描二维码,在手机上阅读!

矩乘bitset算法图论LCA数论STL树链剖分
最后由DQS修改于2017-09-19 23:19
发表新评论
博客已萌萌哒运行
© 2018 由 Typecho 强力驱动.Theme by Yodu
前篇 后篇
雷姆
拉姆