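I have a string S = "&|&&|&&&|&".
I need to count the '&' characters between two indices of the string. For example, for indices 1 and 8 the output should be 5. Here is my brute-force code: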
std::size_t cnt = 0;
for (std::size_t i = start; i < end; i++) {
    if (S[i] == '&')
        cnt++;
}
cout << cnt << endl;
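My problem is that this code times out on the coding platform for larger inputs. Can anyone suggest a better approach that reduces the time complexity?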
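I decided to try several approaches, including those proposed by the other two answers to this question. I made a few assumptions about the input, with the goal of finding a fast implementation for a single large string that is searched only once for a single character. For a string that will receive multiple queries for multiple characters, I recommend building a segment tree, as suggested in user Jefferson Rondan's comment.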
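For the repeated-query case, when the string never changes and only one character is counted, a prefix-count array is a simpler alternative to a segment tree. The sketch below is my own illustration, not something proposed in the thread: the class name PrefixCounter is hypothetical, and it assumes half-open ranges [first, last) with first <= last <= size.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: precomputes, for one target character, how many
// occurrences appear in each prefix of the text. Build cost is O(n);
// every later range query is O(1).
class PrefixCounter {
public:
    PrefixCounter(const std::string& text, char target)
        : prefix_(text.size() + 1, 0) {
        for (std::size_t i = 0; i < text.size(); ++i) {
            // prefix_[i + 1] holds the number of matches in text[0, i + 1)
            prefix_[i + 1] = prefix_[i] + (text[i] == target ? 1 : 0);
        }
    }

    // Number of matches in the half-open range [first, last).
    // Assumes first <= last <= text.size(); no bounds checking is done.
    std::size_t count(std::size_t first, std::size_t last) const {
        return prefix_[last] - prefix_[first];
    }

private:
    std::vector<std::size_t> prefix_;
};
```

With the string from the question, PrefixCounter("&|&&|&&&|&", '&').count(1, 8) gives 5. The trade-off is O(n) extra memory per tracked character, which only pays off when the same string is queried many times.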
I used std::chrono::steady_clock::now() to measure the time taken by each implementation.
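The timing pattern can be wrapped in a small helper so that every method is measured the same way. This is only a sketch: the function name elapsed_microseconds is an assumption of mine and does not appear in the benchmark program.

```cpp
#include <chrono>
#include <cstdint>

// Hypothetical helper: runs `work` once and returns the elapsed wall-clock
// time in microseconds, measured with the monotonic steady_clock.
template <typename Fn>
std::int64_t elapsed_microseconds(Fn&& work) {
    const auto start = std::chrono::steady_clock::now();
    work();
    const auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::microseconds>(stop - start)
        .count();
}
```

A call such as elapsed_microseconds([&] { count = std::count(first, last, c); }) would time one method. steady_clock is a reasonable choice here because it never jumps backwards when the system clock is adjusted.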
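Assumptions
- The program prompts the user for the string size, search character, start index, and end index.
- The input is well-formed (start <= end <= size).
- The string is generated randomly from a uniform distribution over the ASCII characters between ' ' and '~'.
- The underlying data of the string object is stored contiguously in memory.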
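Methods
- Naive for loop: an index variable is incremented and the string is indexed character by character.
- Iterator loop: a string iterator is used, dereferenced on each iteration, and compared with the search character.
- Underlying data pointer: a pointer to the string's underlying character array is obtained and incremented in a loop. The dereferenced pointer is compared with the search character.
- Index mapping (as suggested by GyuHyeon Choi): an array of integer counters, with one element per character up to the largest printable ASCII character, is initialized to 0. For each character encountered while traversing the string, the corresponding element is incremented by 1. At the end, the element at the search character's index gives the number of occurrences.
- Just use std::count (as suggested by Atul Sharma): simply use the built-in counting function.
- Recast the underlying data as a pointer to a larger data type and iterate: the underlying const char* const pointer that holds the string data is reinterpreted as a pointer to a wider data type (in this case, a pointer to uint64_t). Each dereferenced uint64_t is then XORed with a mask made up of the search character, and each byte of the result is tested with a 0xff mask. This reduces the number of pointer increments needed to step through the whole array.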
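The last method can also be written without the per-byte tests. The sketch below is a variation of mine rather than the code that was benchmarked: it loads each 8-byte word with std::memcpy (which sidesteps the alignment and aliasing concerns of a pointer cast), XORs it with the repeated search character so that matching bytes become zero, and then counts the zero bytes with bit arithmetic. The function name count_char_swar is hypothetical, and first <= last <= s.size() is assumed.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>

// Hypothetical helper: counts bytes equal to `target` in s[first, last),
// processing 8 bytes per step. Assumes first <= last <= s.size().
std::size_t count_char_swar(const std::string& s, std::size_t first,
                            std::size_t last, char target) {
    const char* p = s.data() + first;
    std::size_t remaining = last - first;
    std::size_t total = 0;

    constexpr std::uint64_t low7 = 0x7f7f7f7f7f7f7f7fULL;
    constexpr std::uint64_t ones = 0x0101010101010101ULL;
    // The search character repeated in all eight bytes.
    const std::uint64_t pattern = ones * static_cast<unsigned char>(target);

    while (remaining >= 8) {
        std::uint64_t word;
        std::memcpy(&word, p, 8);  // safe unaligned load
        const std::uint64_t x = word ^ pattern;  // matching bytes become 0x00
        // Sets the top bit of every byte of x that is zero, and nothing else.
        const std::uint64_t zero_flags = ~(((x & low7) + low7) | x | low7);
        // Move each flag to bit 0 of its byte, then sum the bytes via multiply.
        total += static_cast<std::size_t>(((zero_flags >> 7) * ones) >> 56);
        p += 8;
        remaining -= 8;
    }
    // Fewer than 8 bytes are left; count them one at a time.
    total += static_cast<std::size_t>(std::count(p, p + remaining, target));
    return total;
}
```

Whether this beats std::count depends on the compiler and flags, since optimizers can often vectorize the plain loop on their own, so it is worth measuring before adopting it.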
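Results
For a string of size 1,000,000,000 searched from index 5 to 999999995, the results for each method are as follows:
- Naive loop: 843 ms
- Iterator loop: 818 ms
- Underlying data pointer: 750 ms
- Index mapping (as suggested by GyuHyeon Choi): 929 ms
- Just use std::count (as suggested by Atul Sharma): 819 ms
- Recast the underlying data as a pointer to a larger data type and iterate: 664 ms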
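Discussion
The best-performing implementation is my own data pointer recast, which finishes in roughly 79% of the time taken by the naive solution (664 ms versus 843 ms). The fastest "simple" solution is pointer iteration over the underlying data. That method has the advantage of being easy to implement, understand, and maintain. The index mapping approach, despite being touted as 2x faster than the naive solution, saw no such speedup in my benchmarks; at 929 ms it was the slowest method. The std::count approach is about as fast as the hand-written iterator loop (819 ms versus 818 ms) and is even simpler to implement. If speed really matters, consider recasting the underlying pointer. Otherwise, stick with std::count.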
Code
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <random>
#include <string>
#include <typeinfo>
int main(int argc, char** argv)
{
    std::random_device device;
    std::mt19937 generator(device());
    // Uniform distribution over the printable ASCII range [' ', '~']
    std::uniform_int_distribution<short> short_distribution(' ', '~');
    auto next_short = std::bind(short_distribution, generator);
    std::string random_string = "";
    size_t string_size;
    size_t start_search_index;
    size_t end_search_index;
    char search_char;
    std::cout << "String size: ";
    std::cin >> string_size;
    std::cout << "Search char: ";
    std::cin >> search_char;
    std::cout << "Start search index: ";
    std::cin >> start_search_index;
    std::cout << "End search index: ";
    std::cin >> end_search_index;
    if (!(start_search_index <= end_search_index && end_search_index <= string_size))
    {
        std::cout << "Requires start_search <= end_search <= string_size\n";
        return 0;
    }
    // Build the random test string one character at a time
    for (size_t i = 0; i < string_size; i++)
    {
        random_string += static_cast<char>(next_short());
    }
    // naive implementation
    size_t count = 0;
    auto start_time = std::chrono::steady_clock::now();
    for (size_t i = start_search_index; i < end_search_index; i++)
    {
        if (random_string[i] == search_char)
            count++;
    }
    auto end_time = std::chrono::steady_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    std::cout << "Naive implementation. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
    // Iterator solution
    count = 0;
    start_time = std::chrono::steady_clock::now();
    for (auto it = random_string.begin() + start_search_index, end = random_string.begin() + end_search_index;
         it != end;
         it++)
    {
        if (*it == search_char)
            count++;
    }
    end_time = std::chrono::steady_clock::now();
    duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    std::cout << "Iterator solution. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
    // Iterate on data
    count = 0;
    start_time = std::chrono::steady_clock::now();
    for (auto it = random_string.data() + start_search_index,
         end = random_string.data() + end_search_index;
         it != end; it++)
    {
        if (*it == search_char)
            count++;
    }
    end_time = std::chrono::steady_clock::now();
    duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    std::cout << "Iterate on underlying data solution. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
    // use index mapping
    count = 0;
    // One counter per character code 0..'~'; the +1 keeps index '~' in bounds
    size_t count_array['~' + 1]{ 0 };
    start_time = std::chrono::steady_clock::now();
    for (size_t i = start_search_index; i < end_search_index; i++)
    {
        count_array[static_cast<unsigned char>(random_string.at(i))]++;
    }
    end_time = std::chrono::steady_clock::now();
    duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    count = count_array[static_cast<unsigned char>(search_char)];
    std::cout << "Using index mapping. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
    // using std::count
    count = 0;
    start_time = std::chrono::steady_clock::now();
    count = std::count(random_string.begin() + start_search_index
                     , random_string.begin() + end_search_index
                     , search_char);
    end_time = std::chrono::steady_clock::now();
    duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    std::cout << "Using std::count. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
    // Iterate on larger type than underlying char.
    // Start from the range length and subtract one for every byte that does not match.
    count = end_search_index - start_search_index;
    start_time = std::chrono::steady_clock::now();
    {
        auto it = random_string.data() + start_search_index;
        auto end = random_string.data() + end_search_index;
        // iterate byte by byte until we reach an address that is divisible by 8
        for (; (reinterpret_cast<std::uintptr_t>(it) & 0x07) && it != end; it++)
        {
            if (*it != search_char)
                count--;
        }
        // iterate on 8-byte sized chunks until we reach the last full chunk that is 8-byte aligned.
        // Note: reading char data through a uint64_t pointer is not sanctioned by the C++
        // aliasing rules; copying each chunk with std::memcpy is the portable alternative.
        auto chunk_it = reinterpret_cast<const uint64_t*>(it);
        auto chunk_end = reinterpret_cast<const uint64_t*>(
            reinterpret_cast<std::uintptr_t>(end) & ~static_cast<std::uintptr_t>(0x07));
        // build a mask with the search character repeated in every byte
        uint64_t search_xor_mask = 0;
        for (size_t i = 0; i < 64; i += 8)
        {
            search_xor_mask |= (static_cast<uint64_t>(static_cast<unsigned char>(search_char)) << i);
        }
        constexpr uint64_t all_ones = 0xff;
        // '<' rather than '!=': for a short range, chunk_end can lie below chunk_it
        for (; chunk_it < chunk_end; chunk_it++)
        {
            // after the XOR, a byte is zero exactly when it matched the search character
            auto chunk = (*chunk_it ^ search_xor_mask);
            if (chunk & (all_ones << 56))
                count--;
            if (chunk & (all_ones << 48))
                count--;
            if (chunk & (all_ones << 40))
                count--;
            if (chunk & (all_ones << 32))
                count--;
            if (chunk & (all_ones << 24))
                count--;
            if (chunk & (all_ones << 16))
                count--;
            if (chunk & (all_ones << 8))
                count--;
            if (chunk & (all_ones << 0))
                count--;
        }
        // iterate on the remainder of the bytes, should be no more than 7, tops
        it = reinterpret_cast<decltype(it)>(chunk_it);
        for (; it != end; it++)
        {
            if (*it != search_char)
                count--;
        }
    }
    end_time = std::chrono::steady_clock::now();
    duration = std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time);
    std::cout << "Iterate on underlying data with larger step sizes. Found: " << count << "\n";
    std::cout << "Elapsed time: " << duration.count() << "us.\n\n";
}
Sample output
String size: 1000000000
Search char: &
Start search index: 5
End search index: 999999995
Naive implementation. Found: 10527454
Elapsed time: 843179us.
Iterator solution. Found: 10527454
Elapsed time: 817762us.
Iterate on underlying data solution. Found: 10527454
Elapsed time: 749513us.
Using index mapping. Found: 10527454
Elapsed time: 928560us.
Using std::count. Found: 10527454
Elapsed time: 819412us.
Iterate on underlying data with larger step sizes. Found: 10527454
Elapsed time: 664338us.
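int cnt[128] = {0}; // one zeroed counter per 7-bit ASCII code; '&' = 38, '|' = 124
for (int i = start; i < end; i++) {
    cnt[S.at(i)]++; // assumes S contains only 7-bit ASCII characters
}
cout << cnt['&'] << endl;
An if comparison and the branch it introduces are expensive, so this should be better.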
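If avoiding the explicit if is the goal, another option is to add the result of the comparison directly to the counter. This is a minimal sketch of that idea, not code from the thread: the function name count_branchless is hypothetical, and first <= last <= s.size() is assumed.

```cpp
#include <cstddef>
#include <string>

// Hypothetical helper: counts `target` in s[first, last) with no explicit
// `if` in the loop body. The comparison yields a bool, which converts to
// 0 or 1 and is added to the running total.
std::size_t count_branchless(const std::string& s, std::size_t first,
                             std::size_t last, char target) {
    std::size_t total = 0;
    for (std::size_t i = first; i < last; ++i) {
        total += static_cast<std::size_t>(s[i] == target);
    }
    return total;
}
```

Optimizing compilers frequently generate the same machine code for this form and for the if version, so any gain should be confirmed by measurement.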
You can use std::count from the C++ standard library's algorithm header. Just include the header <algorithm>.
std::string s{"&|&&|&&&|&"};
// https://en.cppreference.com/w/cpp/algorithm/count
auto const count = std::count(s.begin() + 1 // starting index
                             ,s.begin() + 8 // one past the end index
                             ,'&');