C++ Bitset algorithm



给定一个填充为1或0的nxn网格。我想数一下有多少个子网格的角砖都是15。我的解决方案遍历所有行对并计算匹配1的数量,然后使用公式nummof1s * (nummof1s -1)/2并将结果相加。然而,当我在https://cses.fi/problemset/task/2137上提交我的解决方案时,n = 3000的输入没有输出(可能是由某些错误引起的)。错误是什么呢?

int main()
{

int n; cin>> n;
vector<bitset<3000>> grid(n);
for(int i=0;i<n;i++){
cin >> grid[i];
}
long result = 0;
for(int i=0;i<n-1;i++){
for(int j=i+1;j<n;j++){
int count = (grid[i]&grid[j]).count();
result += (count*(count-1))/2;
}
}
cout << result;
}

此解决方案将导致超过时间限制。在最坏情况下,bitset::count()为0 (n)。代码的总复杂度是O(n^3)。在最坏的情况下,操作次数为3000^3>10^10太大了

我不确定这个解决方案是您能想到的最好的解决方案,但它是基于原始解决方案的,为bitset提供了一个自制替代方案。这允许我使用64位块,并使用快速的popcnt()。硬件版本会更好,因为它可以与AVX寄存器一起工作,但这应该更便携,它可以在cses.fi上工作。基本上,count_common()函数不是生成一个长交集bitset,然后计算1的数量,而是生成交集的一个片段,并立即使用它来计算1。

流提取器可能会得到改进,从而节省更多的时间。

#include <iostream>
#include <array>
#include <cstdint>
#include <climits>

uint64_t popcnt(uint64_t v) {
v = v - ((v >> 1) & (uint64_t)~(uint64_t)0 / 3);
v = (v & (uint64_t)~(uint64_t)0 / 15 * 3) + ((v >> 2) & (uint64_t)~(uint64_t)0 / 15 * 3);
v = (v + (v >> 4)) & (uint64_t)~(uint64_t)0 / 255 * 15;
uint64_t c = (uint64_t)(v * ((uint64_t)~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT;
return c;
}

struct line {
uint64_t cells_[47] = { 0 }; // 3000/64 = 47

uint64_t& operator[](int pos) { return cells_[pos]; }
const uint64_t& operator[](int pos) const { return cells_[pos]; }
};

uint64_t count_common(const line& a, const line& b) {
uint64_t u = 0;
for (int i = 0; i < 47; ++i) {
u += popcnt(a[i] & b[i]);
}
return u;
}

std::istream& operator>>(std::istream& is, line& ln) {
is >> std::ws;
int pos = 0;
uint64_t val = 0;
while (true) {
char ch = is.get();
if (is && ch == 'n') {
break;
}
if (ch == '1') {
val |= 1LL << (63 - pos % 64);
}
if ((pos + 1) % 64 == 0) {
ln[pos / 64] = val;
val = 0;
}
++pos;
}
if (pos % 64 != 0) {
ln[pos / 64] = val;
}
return is;
}

struct grid {
int n_;
std::array<line, 3000> data_;

line& operator[](int r) {
return data_[r];
}
};

std::istream& operator>>(std::istream& is, grid& g) {
is >> g.n_;
for (int r = 0; r < g.n_; ++r) {
is >> g[r];
}
return is;
}

int main()
{
grid g;
std::cin >> g;

uint64_t count = 0;
for (int r1 = 0; r1 < g.n_; ++r1) {
for (int r2 = r1 + 1; r2 < g.n_; ++r2) {
uint64_t n = count_common(g[r1], g[r2]);
count += n * (n - 1) / 2;
}
}
std::cout << count << 'n';
return 0;
}

最新更新