模拟退火

骗分大法好!

1. 定义

一种神仙的随机算法,在你找不到 DP 的方法时可以使用来骗分。比如我在 NOIP 考场上 T3 就骗了 96 pts(虽然差一点就想到了 DP 正解)。

2. 思想

首先,它本身是在一个很大的 $x$ 取值范围寻找函数的最优解(注意函数不一定是单峰或单调)。

下面默认寻找最小值。

算法就是我们找到一个初值,然后一直随机跳,找到最优解。

再定义一个温度 $t$,可以表示随机的范围。比如当前的区间就是 $[x_0-dt,x_0+dt]$,其中 $d$ 是常数。

$t$ 会不断衰减,定义一个衰减系数,每一次 $t$ 都会乘上衰减系数,一般很靠近 1。

假设我们当前是 $x_0$,随机到的点是 $x$,再假设 $f(x)-f(x_0)=\Delta$。

  1. $\Delta<0$,则 $x_0$ 不优,我们直接跳到 $x$ 即可。
  2. $\Delta\geq0$,则 $x$ 不优。注意函数可能不止一个峰值,所以有可能 $x$ 的位置更靠近最小值,我们就需要有一定的概率跳过去。$t$ 越小,概率越小,$\Delta$ 越大,概率越小。

对于第二种情况,这个概率各家都不同,一般来看,可以使用 $e^{-\Delta kt}$ 的概率。其中 $k$ 是常数,需要自己定义。

很明显,$0<e^{-\Delta kt}<1$。

好,上面是理论,下面才是重点。(骗分嘛,怎么好怎么来)

3. 实践与经验

诚然,一次的答案确实可能是局部的最优解而不是全局的。

于是,我们就多做几次,每一次都这么走,那么肯定走到全局的最优解的概率是越大的。

那么,我们可以得到第一个结论:时间越多,答案的正确性越大。

还有很多的题目,是无法控制步长的(甚至连函数都不算),我们经过实验,可以发现:衰减系数越大,答案的正确性越大。

然后,尽可能的推出一些性质(比如贪心构造,即使可能正确性有误),尽量的靠近最优解。这样的正确性比较高。(比如 NOIP T3)

下面给出一个简单的模板:

1
2
3
4
5
6
7
8
9
10
11
12
int sa()
{
int ans = INF, now;
double k = ;//自定义参数
for (double t = 1e7; t > 1e-7; t *= 0.999997)//衰减参数
{//先生成新解 cur
if (clock() / CLOCKS_PER_SEC > 0.9) return ans;
int cur = get_ans();
if (cur < now || exp(- (cur - now) * t * k) < rand() / RAND_MAX) now = cur;//比较两解,如果大的话有一定概率跳过去
}
return ans;
}

4. 例题

T1:[NOIP2021 T3]方差

题目传送门 Luogu

注意至少要推出差分的性质,得分会至少有 50-70 pts。

给出考场的 96 pts 代码,Luogu 民间数据 88 pts。(去掉了 freopen,中文注释是后来加的)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <ctime>
#include <cmath>
using namespace std;

typedef long long ll;
const int N = 1e5 + 10;
int a[N], n, b[N], tmp[N], pre[N], suf[N];

template <class T>
inline void read(T &x)
{
static char c;
bool flag = 0;
while ((c = getchar()) < '0' || c > '9')
if (c == '-') flag = 1;
x = c - '0';
while ((c = getchar()) >= '0' && c <= '9')
x = (x << 1) + (x << 3) + c - '0';
if (flag) x = -x;
}

template <class T, class ...T1>
inline void read(T &x, T1 &...x1)
{
read(x), read(x1...);
}

ll get_S()//得到方差
{
a[1] = 0;
for (int i = 2; i <= n; ++ i) a[i] = a[i - 1] + tmp[i - 1];
ll sum1 = 0, sum2 = 0;
for (int i = 1; i <= n; ++ i)
sum1 += a[i], sum2 += a[i] * a[i];
return sum2 * n - sum1 * sum1;
}

/*void solve_to_right(int l, int r)
{
// for (int times = 1; times <= r - l + 1; ++ times)
for (int i = l; i <= r; ++ i)
a[i] = min(a[i], a[i - 1] + a[i + 1] - a[i]);
}

void solve_to_left(int l, int r)
{
// for (int times = 1; times <= r - l + 1; ++ times)
for (int i = r; i >= l; -- i)
a[i] = max(a[i], a[i - 1] + a[i + 1] - a[i]);
}

void solve(int l, int r)
{
if (r - l + 1 <= 2) return;
if (r - l + 1 == 3)
{
if (a[l] - a[1] > a[n] - a[r]) a[l + 1] = min(a[l + 1], a[l] + a[r] - a[l + 1]);
else a[l + 1] = max(a[l + 1], a[l] + a[r] - a[l + 1]);
return;
}
if (a[l] - a[1] >= a[n] - a[r]) solve_to_right(l + 1, r - 1), solve_to_left(l + 1, r - 2);
else solve_to_left(l + 1, r - 1), solve_to_right(l + 2, r - 1);
solve(l + 1, r - 1);
}*/

void sa()
{
for (double t = 1e9; t > 1e-7; t *= 0.9999997)
{
// cout << t << endl;
ll now = get_S(), nw;
int i = rand() % (n - 1) + 1, j = rand() % (n - 1) + 1;
if ((1.0 * clock() / CLOCKS_PER_SEC) > 0.9) return;
if (tmp[i] == tmp[j]) continue;
swap(tmp[i], tmp[j]);
nw = get_S();
// printf("%lld %lld\n", now, nw);
if (nw > now)
if (exp(-(1.0 * nw - now) / now * t) * RAND_MAX < rand()) swap(tmp[i], tmp[j]);
}
}

int main()
{
// freopen("variance.in", "r", stdin);
// freopen("variance.out", "w", stdout);
srand(2206704740U);
read(n);
for (int i = 1; i <= n; ++ i) read(a[i]);
for (int i = 1; i < n; ++ i) b[i] = a[i + 1] - a[i];
sort(b + 1, b + n);
reverse(b + 1, b + n);
for (int i = 1, j = n - 1, tot = 0; i <= j;)
{//构造贪心,随意构造,可以不管
if (pre[i - 1] <= suf[j + 1]) tmp[i] = b[++ tot], pre[i] = pre[i - 1] + tmp[i], i ++;
else tmp[j] = b[++ tot], suf[j] = suf[j + 1] + tmp[j], j --;
}
sa();
// solve(1, n);
// for (int i = 1; i < n; ++ i) cout << tmp[i] << " \n"[i == n - 1];
printf("%lld\n", get_S());
return 0;
}