區間第K大數——划分樹(POJ2104解題報告)


百度百科:划分樹是一種基於線段樹的數據結構。主要用於快速求出(在log(n)的時間復雜度內)序列區間的第k大值。

划分樹的基本思想就是對於某個區間,把它划分成兩個子區間,左邊區間的數小於右邊區間的數。查找的時候通過記錄進入左子樹的數的個數,確定下一個查找區間,最后范圍縮小到1,就找到了。

建樹

划分樹建樹的時間復雜度和線段樹不同,划分樹是O(nlogn),划分樹的建樹依賴一個排好序的數組來輔助建樹。

先將讀入的數據排序,然后,對於l,r區間,取他們的中間值sorted[mid],然后依次掃l~r區間,將每個節點划分到兒子(即l~mid和mid+1~r)中。注意,這里面分到每個子樹的節點是相對有序的,即對於分到每一顆子樹里面的數,不改變它們以前的相對順序。另外,在進行這個過程的時候,記錄一個類似前綴和的東西,即l到i這個區間內有多少節點划分到左子樹。

畫了一顆划分樹對數列[1 5 2 3 6 4 7 3 0 0]進行划分,下圖有助於理解(紅色表示該數被分到左兒子)//摘自小HH's Blog

划分樹

下面給出建樹的具體實現代碼並具體解釋(這里只是個人的理解),首先給出划分樹的數據結構

struct Node
{
int l,r;
int mid()
{
return (l+r)>>1;
}
} tree[N<<2];

int sorted[N];
int val[20][N],toLeft[20][N];
sorted[]是對輸入的數據進行排序后存放的數組,目的是用來輔助建樹。

val[][]有人會奇怪這里為什么用二維數組,第一維是用來記錄到第幾層,第二維是用來記錄這層上的所有數

toLeft[][]是用來記錄區間[tree[rt].l , l - 1]有多少個數划分到了左邊

如果還是不明白具體看下面代碼

void build(int l,int r,int rt,int deep)//復雜度nlogn
{
tree[rt].l = l;//記錄每個節點的左右兩個端點
tree[rt].r = r;
if(l == r) return;
int m = tree[rt].mid();
int midval = sorted[m];//取出每條線段上所有數的中值
int leftsame = m - l + 1;//表示在該條線段的左半部分上的數有多少和midval相等的數,先假定從[l,m]的數都和midval相等
for(int i = l ; i <= r ; i ++)//這個循環是用來求出leftsame的具體值
{
if(val[deep][i] < midval)
--leftsame;
}
int lpos = l,rpos = m + 1;//這里是用來求出下一層val[deep+1][]的值,也就是划分的過程
for(int i = l ; i <= r ; i++)//這里注意這時處理該條線段上的所有數值
{
if(i == l) toLeft[deep][i] = 0;//toLeft[][i]表示[tree[rt].l,i-1]有多少數在左部
else toLeft[deep][i] = toLeft[deep][i-1];//這里相當於對toLeft[][]初始化
if(val[deep][i] < midval)//如果小於midval 對toLeft[deep][i] ++
{
++toLeft[deep][i];
val[deep+1][lpos++] = val[deep][i];//把該值放到下一層的左邊
}
else if(val[deep][i] > midval)
{
val[deep+1][rpos++] = val[deep][i];
}
else//判斷和midval相等的數是放在左部還是右部
{
if(leftsame >= 0)//這里表示只能放置leftsame個數,多余的都要放到右子樹上去
{
--leftsame;
++toLeft[deep][i];
val[deep+1][lpos++] = val[deep][i];
}
else//放到下層的右子樹上
{
val[deep+1][rpos++] = val[deep][i];
}
}
}
build(l,m,rt<<1,deep+1);
build(m+1,r,rt<<1|1,deep+1);
}
查詢
假定查詢區間是[l,r]查詢第k大的數,那么肯定是當l == r的時候返回val[][l]。那么其他情況該怎么處理呢?

設定當前區間在線段[s,t]上,這時有區間[s,l-1]有toLeft[][l-1]個數進入下一層的左子樹,區間[s,r]有toLeft[][r]個數進入下層的左子樹,這時我們能夠求出在[l,r]區間有,sum = toLeft[][r] - toLeft[][l-1]個數進入下一層的左子樹,那么如果sum>=k則遞歸到左子樹查詢,否則遞歸到右子樹。到這里應該都很容易理解。

難點就是現在知道了應該遞歸到左右子樹,那么遞歸的區間呢?

首先,遞歸到左子樹,那么現在這條線段上的數全部都是上一條線段的數應該進入左子樹的,因此,這條線段的左邊是上個線段[s,l-1]區間里的toLeft[][l-1]個數,緊接着的就是sum個數(toLeft[][r] - toLeft[][l-1]),所以我們能夠得到新的查詢區間應該是[s+toLeft[][l-1],s+sum-1],這里-1是為了處理邊界問題,值得大家認真思索。

同理,遞歸到右子樹,對於現在這條線段上的數全部都是上一條線段的數應該進入右子樹的,這條線段的左邊是上條線段[s,l-1]區間里的 lsum = l - 1 - toLeft[][l-1] + 1個數,緊接着就是rsum =  r - st  - sum個數,所以查詢區間應該是[mid + 1 + lsum, mid + 1 + rsum]。

下面給出query代碼

int query(int l,int r,int k,int rt,int deep)
{
if(l == r) return val[deep][l];
//下面就是要確認新的查找區間
int s;//表示[l,r]里在左邊的數的個數
int ss;//表示[tree[rt].l,l-1]里在左邊的數的個數
if(l == tree[rt].l)
{
s = toLeft[deep][r];
ss = 0;
}
else
{
ss = toLeft[deep][l-1];
s = toLeft[deep][r] - ss;
}
//注意這里的在左邊的數都是和sorted[m]相比的,由此可以得到如果s>=k就去左子樹找,相反則去右子樹
if(s >= k)
{
/*進入左子樹,該條線段的左邊有ss個數及從是從上面[tree[rt].l,l-1]該進入左子樹的數繼承而來
*接着還應該有toLeft[deep][r] - toLeft[deep][l-1]個數即s個數
*所以可以確定新的查找區間應該是[tree[rt].l+ss,newl+s - 1]
*/
int newl = tree[rt].l + ss;
int newr = newl + s - 1;//這里減1是為了處理邊界問題
return query(newl,newr,k,rt<<1,deep+1);
}
else
{
/*
*進入右子樹,該條線段的左邊應該是上條線段[tree[rt].l,l-1]應該進入右子樹的數,即bb = l - tree[rt].l - ss個數
*接着還應該有上條線段[l,r]應該進入右子樹的數,即b = r - l + 1 - s個數
*所以可以確定新的查詢區間應該是[mid + 1 + bb,mid + 1 + bb + b - 1],這里的-1同一是為了處理邊界問題
*/
int m = tree[rt].mid();
int b = r - l + 1 - s;//表示[l,r]在右邊的數的個數
int bb = (l - 1) - tree[rt].l + 1 - ss;//表示[tree[rt].l,l-1]在右邊的數的個數
int newl = m + 1 + bb;
int newr = m + b + bb;//m + r - l + 1 - toLeft[deep][r] + ss - l - tree[rt].l - ss = m+r-
return query(newl,newr,k-s,rt<<1|1,deep+1);
}
}


至此,划分樹介紹完畢,代碼中的rt表示根節點,deep表示深度(也就是遞歸到第幾層)。

最后給出完整的C++源碼

const int N = 100005;
struct Node
{
int l,r;
int mid()
{
return (l+r)>>1;
}
} tree[N<<2];

int sorted[N];
int val[20][N],toLeft[20][N];

void build(int l,int r,int rt,int deep)
{
tree[rt].l = l;
tree[rt].r = r;
if(l == r) return;
int m = tree[rt].mid();
int midval = sorted[m];
int leftsame = m - l + 1;//表示在左子樹上有多少和midval相等的數
for(int i = l ; i <= r ; i ++)
{
if(val[deep][i] < midval)
--leftsame;
}
int lpos = l,rpos = m + 1;
for(int i = l ; i <= r ; i++)
{
if(i == l) toLeft[deep][i] = 0;//toLeft[][i]表示[tree[rt].l,i-1]有多少數在左部
else toLeft[deep][i] = toLeft[deep][i-1];//這里相當於對toLeft[][]初始化
if(val[deep][i] < midval)
{
++toLeft[deep][i];
val[deep+1][lpos++] = val[deep][i];
}
else if(val[deep][i] > midval)
{
val[deep+1][rpos++] = val[deep][i];
}
else//判斷和midval相等的數是放在左部還是右部
{
if(leftsame >= 0)
{
--leftsame;
++toLeft[deep][i];
val[deep+1][lpos++] = val[deep][i];
}
else
{
val[deep+1][rpos++] = val[deep][i];
}
}
}
build(l,m,rt<<1,deep+1);
build(m+1,r,rt<<1|1,deep+1);
}

int query(int l,int r,int k,int rt,int deep)
{
if(l == r) return val[deep][l];
//下面就是要確認新的查找區間
int s;//表示[l,r]里在左邊的數的個數
int ss;//表示[tree[rt].l,l-1]里在左邊的數的個數
if(l == tree[rt].l)
{
s = toLeft[deep][r];
ss = 0;
}
else
{
ss = toLeft[deep][l-1];
s = toLeft[deep][r] - ss;
}
//注意這里的在左邊的數都是和sorted[m]相比的,由此可以得到如果s>=k就去左子樹找,相反則去右子樹
if(s >= k)
{
/*進入左子樹,該條線段的左邊有ss個數及從是從上面[tree[rt].l,l-1]該進入左子樹的數繼承而來
*接着還應該有toLeft[deep][r] - toLeft[deep][l-1]個數即s個數
*所以可以確定新的查找區間應該是[tree[rt].l+ss,newl+s - 1]
*/
int newl = tree[rt].l + ss;
int newr = newl + s - 1;//這里減1是為了處理邊界問題
return query(newl,newr,k,rt<<1,deep+1);
}
else
{
/*
*進入右子樹,該條線段的左邊應該是上條線段[tree[rt].l,l-1]應該進入右子樹的數,即bb = l - tree[rt].l - ss個數
*接着還應該有上條線段[l,r]應該進入右子樹的數,即b = r - l + 1 - s個數
*所以可以確定新的查詢區間應該是[mid + 1 + bb,mid + 1 + bb + b - 1],這里的-1同一是為了處理邊界問題
*/
int m = tree[rt].mid();
int b = r - l + 1 - s;//表示[l,r]在右邊的數的個數
int bb = (l - 1) - tree[rt].l + 1 - ss;//表示[tree[rt].l,l-1]在右邊的數的個數
int newl = m + 1 + bb;
int newr = m + b + bb;//m + r - l + 1 - toLeft[deep][r] + ss - l - tree[rt].l - ss = m+r-
return query(newl,newr,k-s,rt<<1|1,deep+1);
}
}

static inline int Rint()//這段是整型數的輸入外掛,可以忽略不用看
{
struct X
{
int dig[256];
X()
{
for(int i = '0'; i <= '9'; ++i) dig[i] = 1;
dig['-'] = 1;
}
};
static X fuck;
int s = 1, v = 0, c;
for (; !fuck.dig[c = getchar()];);
if (c == '-') s = 0;
else if (fuck.dig[c]) v = c ^ 48;
for (; fuck.dig[c = getchar()]; v = v * 10 + (c ^ 48));
return s ? v : -v;
}


int main()
{
int n,m;
while(~scanf("%d %d",&n,&m))
{
for(int i = 1 ; i <= n ; i++)
{
scanf("%d",&val[0][i]);
sorted[i] = val[0][i];
}
sort(sorted+1,sorted+n+1);
build(1,n,1,0);
while(m--)
{
int a,b,c;
scanf("%d %d %d",&a,&b,&c);
printf("%d\n",query(a,b,c,1,0));
}
}
return 0;
}





注意!

本站转载的文章为个人学习借鉴使用,本站对版权不负任何法律责任。如果侵犯了您的隐私权益,请联系我们删除。



 
粤ICP备14056181号  © 2014-2021 ITdaan.com