GJK算法计算凸多边形之间的距离

时间:2022-07-23
本文章向大家介绍GJK算法计算凸多边形之间的距离,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。

缘起

《你被追尾了续》中我们学习了 GJK 碰撞检测算法. 但其实 GJK 算法发明出来的初衷是计算凸多边形之间的距离的. 所以我们来学习一下这种算法.

分析

根据《你被追尾了续》的学习,我们知道,其实就是求 坐标原点到Minkowski和(也是一个凸多边形)的距离. 以下图为例,显然shape1(三角形)和 shape2(四边形)没有交集,然后我们想计算它俩之前的距离

做出它俩的 Minkowski 和如下

所以答案就是 OD 的长度. 所以我们自然要解决的问题是,怎么快速能知道答案是坐标原点到 (-4,1) 到 (1, 3) 的线段的距离呢? 和 GJK 碰撞检测类似的,我们不能

O(n^2)

暴力枚举 Minkowski和的所有的点,这里也是使用迭代.

其实和GJK碰撞检测完全类似,我们也需要用supprt 函数迭代的构建单纯形,伪代码如下

d = c2 - c1; // 和GJK碰撞检测类似
Simplex.add(support(shape1, shape2, d)); // Simplex 中加入 a 点
Simplex.add(support(shape1, shape2, -d));  // Simplex 中加入 b 点
// 从原点指向 ab 线段上距离原点最近的点的向量, 例如恰好就是答案的话, 则 d.magnitude() 就是答案
d = ClosestPointToOrigin(Simplex.a, Simplex.b);
while (true) {
  // 上面的 ClosestPointToOrigin 得到的 d 方向是从原点到 ab 线段上的最近点的, 所以将d反向, 则指向原点
  d.negate();
  if (d.isZero()) {
    // 则 原点在 Minkowski 和中, 所以发生了碰撞, 但是这当然不是碰撞的唯一情形
    return false;
  }
  // 计算出新的点c, 注意, 因为 d 是朝向坐标原点的
  c = support(shape1, shape2, d);
  // c 在 d 方向上的投影
  double dc = c.dot(d);
  // a 在 d 方向上的投影 
  double da = Simplex.a.dot(d);
  // 这个 tolerance 其实对于多边形而言, 取 0 就好了, tolerance 是对于 曲边形才需要设置
  if (|dc - da| < tolerance) {
    distance = d.magnitude();
    return true;
  }
  // 因为我们已经知道了 c 比 a、b更接近原点, 所以只需要保留 a、b中更接近原点的那个点了
  // 然后保留下来的点和c共同构成新的单纯形(即一条线段)
  p1 = ClosestPointToOrigin(Simplex.a, c);
  p2 = ClosestPointToOrigin(Simplex.b, c);
  if (p1.magnitude() < p2.magnitude()) {
    Simplex.b = c;
    d = p1;
  } else {
    Simplex.a = c;
    d = p2;
  }
}

和GJK碰撞检测中的伪代码类似,但是有一个重要的区别在于上面的伪代码始终保持 单纯形S 中只有2个点. 用上面的Figure 1的例子解释上面的伪代码

首先是初始化

d = (11.5, 4.0) - (5.5, 8.5) = (6, -4.5)
Simplex.add(support(shape1, shape2, d)) = (9, 9) - (8, 6) = (1, 3)
Simplex.add(support(shape1, shape2, -d)) = (4, 11) - (13, 1) = (-9, 10)
// 计算新的 d, 直接看 Figure 4 即可,因为 ab 线段上距离原点最近的点就是 (1,3), 所以 d = (1, 3)
d = (1, 3) 
d = (-1, -3) // d 反向

然后开始迭代(对应 Figure 5)

// 用support 函数计算新的点
c = support(shape1, shape2, d) = (4, 5) - (15, 6) = (-11, -1)
// 14 - (-10) = 24 还不够小,  所以我们不能终止循环, 事实上, 对于凸多边形的情况, 只有为 0 了, 才算够小
dc = 11 + 3 = 14
da = -1 - 9 = -10
// 边 AC [(1, 3) to (-11, -1)] 和 边 BC [(-9, 10) to (-11, -1)]相比更加接近原点, 所以应该蒯掉 b 点, 也就是将 b 替换为 新发现的, 更加接近原点的 c
b = c
// 设置新的 d, 我们实际上是朝着原点在不断进发
d = p1

注意,我们的单纯形从线段 ab 变为了 线段 ac, 所以我们的单纯形其实更加靠近坐标原点了.

然后再进行一次迭代

d = (1.07, -1.34) // 即上面的 p1 的相反向量 -p1
// 根据
c = support(shape1, shape2, d) = (9, 9) - (8, 6) = (1, 3)
// 够小了!
dc = -1.07 + 4.02 = 2.95
da = -1.07 + 4.02 = 2.95
// ||d|| = 1.7147886 我们完成了迭代
distance = 1.71

于是答案就是 1.71 了.

注意,如果 shape1 和 shape2 是凸多边形的话,则最后 dc 是一定等于 da 的. 如果 shape1 或者 shape2 中有一个是曲边的,则最后 dc 和 da 之间的距离差可能就不是 0 了. 所以如果选择的 tolerance 太小了的话,则可能一直达不到 tolerance而无限循环,所以我们应该加一个最大循环次数.

在两个物体本来就交叉的情况下,这个算法可能终止条件会失效,从而带来一些问题。一般情况下,我们都会先做碰撞检测,然后再求他们之间的距离

还有一个有趣的问题是,我们已经能求出两个凸多边形的距离了,那么你能更进一步求出产生这个距离的那对点吗? (如果有多对,随意产生一对就行) 其实也很简单,就拿上面的例子来说,

因为我们知道原点到 Minkowski 和的距离是 坐标原点到线段 (1,3)---(-4,-1) 的距离

我们只需要维护一下每个 Minkowski 和上的点是由哪对 shape1 、shape2 上的点构成的就行了, 例如下表

shape1

shape2

对应Minkowski和上的点

(9, 9)

(8, 6)

(1, 3)

(4, 5)

(8, 6)

(-4, -1)

所以我们就知道了,最后实现该最近距离1.71的实际上是 shape1上的线段 (9,9)--(4,5) 到 shape2 上的线段 (8,6)--(8,6) 的距离. 而求两根线段之间的最短距离的实现点对就很简单了.

以下面一道经典的题目来证明上面的算法正确.

题目概述
给定两个不相交的凸多边形,求其之间最近距离

时限
1000ms 64MB

输入
第一行正整数N,M,代表两个凸多边形顶点数,其后N行,每行两个浮点数x,y,描述多边形1的一个点的坐标,其后M
行,每行两个浮点数x,y,描述多边形2的一个点的坐标,输入到N=M=0为止
输入保证是按照顺时针或者逆时针给出凸包上的点.

限制
3<=N,M<=10000;-10000<=x,y<=10000

输出
每行一个浮点数,为所求最近距离,误差在1e-3内均视为正确

样例输入
4 4
0.00000 0.00000
0.00000 1.00000
1.00000 1.00000
1.00000 0.00000
2.00000 0.00000
2.00000 1.00000
3.00000 1.00000
3.00000 0.00000
0 0

样例输出
1.00000

可以使用旋卡. 但这里使用上述 GJK 算法. GJK 算法不要求多边形输入的顶点的顺序——也就是哪怕你乱序输入都行. 所以这也是一种相比旋卡的优势

//#include "stdafx.h"
//#define LOCAL
#pragma GCC optimize(2)
#pragma G++ optimize(2)
#pragma warning(disable:4996)
//#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <stdio.h>
#include <iostream>
#include <iomanip>
#include <string>
#include <ctype.h>
#include <string.h>
#include <fstream>
#include <sstream>
#include <math.h>
#include <map>
//#include <unordered_map>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <time.h>
#include <stdlib.h>
#include <bitset>
using namespace std;
//#define int unsigned long long
//#define int long long
#define re register int
#define ci const int
#define ui unsigned int 
typedef pair<int, int> P;
#define FE(cur) for(re h = head[cur], to; ~h; h = g[h].nxt)
#define ilv inline void
#define ili inline int
#define ilc inline char
#define ild inline double
#define ilp inline P
#define LEN(cur) (hjt[cur].r - hjt[cur].l)
#define MID(cur) (hjt[cur].l + hjt[cur].r >> 1)
#define SQUARE(x) ((x) * (x))
typedef vector<int>::iterator vit;
typedef set<int>::iterator sit;
typedef map<int, int>::iterator mit;
const int inf = ~0u>>1;
const double PI = acos(-1.0), eps = 1e-8;
namespace fastio
{
    const int BUF = 1 << 21;
    char fr[BUF], fw[BUF], *pr1 = fr, *pr2 = fr;int pw;
    ilc gc() { return pr1 == pr2 && (pr2 = (pr1 = fr) + fread(fr, 1, BUF, stdin), *pr2 = 0, pr1 == pr2) ? EOF : *pr1++; }
    ilv flush() { fwrite(fw, 1, pw, stdout); pw = 0; }
    ilv pc(char c) { if (pw >= BUF) flush(); fw[pw++] = c; }
    ili read(int &x)
    {
        x = 0; int f = 1; char c = gc(); if (!~c) return EOF;
        while(!isdigit(c)) { if (c == '-') f = -1; c = gc(); }
        while(isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = gc();
        x *= f; return 1;
    }
    ili read(double &x) 
    {
        int xx = 0; double f = 1.0, fraction = 1.0; char c = gc(); if (!~c) return EOF;
        while (!isdigit(c)) { if (c == '-') f = -1.0; c = gc(); }
        while (isdigit(c)) { xx = (xx << 3) + (xx << 1) + (c ^ 48), c = gc(); }
        x = xx;
        if (c ^ '.') { x = f * xx; return 1; }
        c = gc();
        while (isdigit(c)) x += (c ^ 48) * (fraction /= 10), c = gc();
        x *= f; return 1;
    }
    ilv write(int x) { if (x < 0) pc('-'), x = -x; if (x > 9) write(x / 10); pc(x % 10 + 48); }
    ilv writeln(int x) { write(x);pc(10); }
    ili read(char *x)
    {
        char c = gc(); if (!~c) return EOF;
        while(!isalpha(c) && !isdigit(c)) c = gc();
        while (isalpha(c) || isdigit(c)) *x++ = c, c = gc();
        *x = 0; return 1;
    }
    ili readln(char *x)
    {
        char c = gc(); if (!~c) return EOF;
        while(c == 10) c = gc();
        while(c >= 32 && c <= 126) *x++ = c, c = gc();
        *x = 0; return 1;
    }
    ilv write(char *x) { while(*x) pc(*x++); }
    ilv write(const char *x) { while(*x) pc(*x++); }
    ilv writeln(char *x) { write(x); pc(10); }
    ilv writeln(const char *x) { write(x); pc(10); }
    ilv write(char c) { pc(c); }
    ilv writeln(char c) { write(c); pc(10); }
} using namespace fastio;
const int maxn = 1e4+5;
int n, m;
struct Point
{
    double x, y;
    Point(double x = 0, double y = 0): x(x), y(y){};
    Point operator - (Point o) 
    {
        return Point(x - o.x, y - o.y);
    }
    double operator / (Point o) 
    {
        return x * o.y - y * o.x;
    }
    double operator * (Point o) 
    {
        return x * o.x + y * o.y;
    }
    Point neg()
    {
        return Point(-x, -y);
    }
    double magnitude()
    {
        return sqrt(SQUARE(x) + SQUARE(y));
    }
    Point scalar(double a)
    {
        return Point(x * a, y *a);
    }
    Point normalize()
    {
        return scalar(1 / magnitude());
    }
} shape1[maxn], shape2[maxn], d;
stack<Point> s;

Point center(Point *shape, int n)
{
    Point ans;
    for (re i = 0; i < n; i++)
    {
        ans.x += shape[i].x;
        ans.y += shape[i].y;
    }
    ans.x /= n, ans.y /= n;
    return ans;
}

Point support1(Point *shape, int n, Point d)
{
    double mx = -inf, proj;
    Point ans;
    for (re i = 0; i < n; i++)
    {
        proj = shape[i] * d;
        if (mx < proj)
        {
            mx = proj;
            ans = shape[i];
        }
    }
    return ans;
}

Point support(Point *shape1, Point *shape2, int n1, int n2, Point d)
{
    Point x = support1(shape1, n1, d), y = support1(shape2, n2, d.neg());
    return x - y;
}

ili dcmp(double x)
{
    if (fabs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}

Point perp(Point &a, Point &b, Point &c)
{
    return b.scalar(a * c) - a.scalar(b * c);
}

Point closestPointToOrigin(Point &a, Point &b)
{
    double da = a.magnitude();
    double db = b.magnitude();
    double dis = fabs(a / b) / (a - b).magnitude();
    Point ab = b - a, ba = a - b, ao = a.neg(), bo = b.neg();
    if (ab * ao > 0 && ba * bo > 0) return perp(ab, ao, ab).normalize().scalar(dis);
    return da < db ? a.neg() : b.neg();
}

ild gjk(Point *shape1, Point *shape2, int n1, int n2)
{
    d = center(shape2, n2) - center(shape1, n1);
    Point a = support(shape1, shape2, n1, n2, d);
    Point b = support(shape1, shape2, n1, n2, d.neg());
    Point c, p1, p2;
    d = closestPointToOrigin(a, b);
    s.push(a);
    s.push(b);
    while (1)
    {
        c = support(shape1, shape2, n1, n2, d);
        a = s.top(); s.pop();
        b = s.top(); s.pop();
        double da = d * a, db = d * b, dc = d * c;
        if (!dcmp(dc - da) || !dcmp(dc - db)) return d.magnitude();
        p1 = closestPointToOrigin(a, c);
        p2 = closestPointToOrigin(b, c);
        p1.magnitude() < p2.magnitude() ? s.push(a), d = p1 : s.push(b), d = p2;
        s.push(c);
    }
}

signed main()
{
#ifdef LOCAL
    FILE *ALLIN = freopen("d:\data.in", "r", stdin);
//  freopen("d:\my.out", "w", stdout);
#endif
    while (read(n), read(m), n + m)
    {
        for (re i = 0; i < n; i++) read(shape1[i].x), read(shape1[i].y);
        for (re i = 0; i < m; i++) read(shape2[i].x), read(shape2[i].y);
        printf("%.5lfn", gjk(shape1, shape2, n, m));
    }
    flush();
#ifdef LOCAL
    fclose(ALLIN);
#endif
    return 0;
}

ac情况

Accepted 16ms 2536KB C++

这里进一步解释一下上面的 closestPointToOrigin函数, 简而言之,

closestPointToOrigin(a, b) 返回的就是上图中的

overrightarrow{cO}