// 面试题40:最小的k个数 // 题目:输入n个整数,找出其中最小的k个数。例如输入4、5、1、6、2、7、3、8 // 这8个数字,则最小的4个数字是1、2、3、4。 #include <cstdio> #include "Array.h" #include <set> #include <vector> #include <iostream> #include <functional> using namespace std; // ====================方法1==================== // 基于随机快速排序算法, void GetLeastNumbers_Solution1(int* input, int n, int* output, int k) { if (input == nullptr || output == nullptr || n < k || n <= 0 || k <= 0) return; int start = 0; int end = n - 1; int index = Partition(input, n, start, end); while (index != k - 1) //只要随机排序的数字是第k个数字就跳出 { if (index > k - 1) { end = index - 1; index = Partition(input, n, start, end); } else { start = index + 1; index = Partition(input, n, start, end); //index指的是在input里的 } } for (int i = 0; i <= index; ++i) output[i] = input[i]; } // ====================方法2==================== typedef multiset<int, std::greater<int> > intSet; typedef multiset<int, std::greater<int> >::iterator setIterator; void GetLeastNumbers_Solution2(const vector<int>& data, intSet& leastNumbers, int k) { leastNumbers.clear(); if (k < 1 || data.size() < k) return; vector<int>::const_iterator iter = data.begin(); //数据迭代器 for (; iter != data.end(); ++iter) { if (leastNumbers.size() < k) //堆未满 leastNumbers.insert(*iter); else { setIterator iterGreatest = leastNumbers.begin(); //堆迭代器 if (*iter < *(leastNumbers.begin())) //查看当前数是否小于堆中最大数 { leastNumbers.erase(iterGreatest); //擦除最大数 leastNumbers.insert(*iter); //插入当前数 } } } }
// ====================测试代码==================== void Test(const char* testName, int* data, int n, int* expectedResult, int k) { if (testName != nullptr) printf("%s begins: \n", testName); vector<int> vectorData; for (int i = 0; i < n; ++i) vectorData.push_back(data[i]); if (expectedResult == nullptr) printf("The input is invalid, we don‘t expect any result.\n"); else { printf("Expected result: \n"); for (int i = 0; i < k; ++i) printf("%d\t", expectedResult[i]); printf("\n"); } printf("Result for solution1:\n"); int* output = new int[k]; GetLeastNumbers_Solution1(data, n, output, k); if (expectedResult != nullptr) { for (int i = 0; i < k; ++i) printf("%d\t", output[i]); printf("\n"); } delete[] output; printf("Result for solution2:\n"); intSet leastNumbers; GetLeastNumbers_Solution2(vectorData, leastNumbers, k); printf("The actual output numbers are:\n"); for (setIterator iter = leastNumbers.begin(); iter != leastNumbers.end(); ++iter) printf("%d\t", *iter); printf("\n\n"); } // k小于数组的长度 void Test1() { int data[] = { 4, 5, 1, 6, 2, 7, 3, 8 }; int expected[] = { 1, 2, 3, 4 }; Test("Test1", data, sizeof(data) / sizeof(int), expected, sizeof(expected) / sizeof(int)); } // k等于数组的长度 void Test2() { int data[] = { 4, 5, 1, 6, 2, 7, 3, 8 }; int expected[] = { 1, 2, 3, 4, 5, 6, 7, 8 }; Test("Test2", data, sizeof(data) / sizeof(int), expected, sizeof(expected) / sizeof(int)); } // k大于数组的长度 void Test3() { int data[] = { 4, 5, 1, 6, 2, 7, 3, 8 }; int* expected = nullptr; Test("Test3", data, sizeof(data) / sizeof(int), expected, 10); } // k等于1 void Test4() { int data[] = { 4, 5, 1, 6, 2, 7, 3, 8 }; int expected[] = { 1 }; Test("Test4", data, sizeof(data) / sizeof(int), expected, sizeof(expected) / sizeof(int)); } // k等于0 void Test5() { int data[] = { 4, 5, 1, 6, 2, 7, 3, 8 }; int* expected = nullptr; Test("Test5", data, sizeof(data) / sizeof(int), expected, 0); } // 数组中有相同的数字 void Test6() { int data[] = { 4, 5, 1, 6, 2, 7, 2, 8 }; int expected[] = { 1, 2 }; Test("Test6", data, 
sizeof(data) / sizeof(int), expected, sizeof(expected) / sizeof(int)); } // 输入空指针 void Test7() { int* expected = nullptr; Test("Test7", nullptr, 0, expected, 0); } int main(int argc, char* argv[]) { Test1(); Test2(); Test3(); Test4(); Test5(); Test6(); Test7(); return 0; }
分析:大师,我悟了。
时间复杂度:平均O(n)(基于随机划分的快速选择),最坏情况O(n²)。
// Finds the k smallest values of a vector by repeatedly partitioning around
// a randomly chosen pivot until the pivot lands on position k - 1.
class Solution
{
    // Exchanges the elements at index1 and index2.
    static void Swap(std::vector<int>& input, int index1, int index2)
    {
        const int temp = input[index1];
        input[index1] = input[index2];
        input[index2] = temp;
    }

    // Partitions input[start..end] around a pivot picked with rand().
    // Returns the pivot's final position: every element before it within the
    // range is smaller than the pivot, every element after it is not.
    static int Partition(std::vector<int>& input, int start, int end)
    {
        // Park the pivot at the end so the scan can compare against it.
        const int pivot = std::rand() % (end - start + 1) + start;
        Swap(input, pivot, end);

        // small is the last slot of the "smaller than pivot" prefix.
        int small = start - 1;
        for (int index = start; index < end; ++index)
        {
            if (input[index] < input[end])
            {
                ++small;
                if (small != index)
                    Swap(input, small, index);
            }
        }

        // Drop the pivot right after the prefix. This position, not the scan
        // index (which always equals end here), is what the caller needs.
        ++small;
        Swap(input, small, end);
        return small;
    }

public:
    // Returns k of the smallest values of input, in no particular order.
    // Returns an empty vector when input is empty, k <= 0 or k > input.size().
    // input is taken by value, so the caller's vector is never reordered.
    std::vector<int> GetLeastNumbers_Solution(std::vector<int> input, int k)
    {
        const int n = static_cast<int>(input.size());
        if (n < k || n <= 0 || k <= 0)
            return std::vector<int>();

        int start = 0;
        int end = n - 1;
        int index = Partition(input, start, end);
        while (index != k - 1)
        {
            if (index > k - 1)
                end = index - 1;    // target position is left of the pivot
            else
                start = index + 1;  // target position is right of the pivot
            index = Partition(input, start, end);
        }

        // The first k slots now hold the k smallest values.
        return std::vector<int>(input.begin(), input.begin() + k);
    }
};
原文:https://www.cnblogs.com/ZSY-blog/p/12631074.html