# matplotlib: plotting scatter points and line segments in 3D

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import numpy as np
from laspy.file import File


def read_las_ins(las_file):
    """Load a LAS point cloud together with its class and instance labels.

    Args:
        las_file: Path of the ``.las`` file; it is opened read-only.

    Returns:
        A 2-D ``numpy`` array with one row per point and five columns:
        ``x, y, z, classification, instance``. The values share whatever
        common dtype ``np.vstack`` promotes the five columns to.
    """
    las = File(las_file, mode='r')
    try:
        # Class labels are taken as-is from the LAS file.
        classification = las.classification
        # Assumes the instance label was written into the red (R) field.
        # The inverse index of np.unique remaps each distinct red value to a
        # consecutive integer id starting at 0.
        _, instance = np.unique(las.red, return_inverse=True)
        # np.vstack copies the data, so the result remains valid after the
        # file has been closed.
        points = np.vstack(
            (las.x, las.y, las.z, classification, instance)
        ).transpose()
    finally:
        # Always release the file handle, even if reading a dimension fails.
        las.close()
    return points


def _segment_points(start, end, num=50):
    """Return ``(x, y, z)`` sample arrays evenly spaced from *start* to *end*."""
    return tuple(np.linspace(start[i], end[i], num) for i in range(3))


def plot_3d_with_line(xyz, classes=None, s=0.5, lines=None, cm='Set1'):
    """Show a 3-D scatter plot of points, optionally with green line segments.

    Args:
        xyz: 2-D array whose first three columns are the x, y and z
            coordinates of the points.
        classes: Optional per-point values used to colour the points through
            the colormap. When ``None`` the points get the default colour.
        s: Marker size passed to ``scatter``.
        lines: Optional iterable of segments. Each segment is a pair
            ``(start, end)`` of 3-D points ``(x, y, z)``; it is drawn as a
            straight green line from ``start`` to ``end``.
        cm: Name of the matplotlib colormap used with ``classes``.

    Returns:
        ``True`` once the figure window has been shown.
    """
    cm = plt.get_cmap(cm)
    fig = plt.figure()
    # fig.gca(projection='3d') was deprecated in Matplotlib 3.4 and later
    # removed; add_subplot is the supported way to create 3-D axes.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlabel(xlabel='x/m',
                  labelpad=0,  # default: 4.0
                  loc='right',  # default: 'center'
                  )
    ax.set_ylabel(ylabel='y/m',
                  labelpad=0,  # default: 4.0
                  )
    ax.set_zlabel(zlabel='z/m',
                  labelpad=0,  # default: 4.0
                  )

    if classes is not None:
        ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=s, c=classes, cmap=cm)
    else:
        # Without ``c`` there is nothing to colour-map, so passing ``cmap``
        # would only trigger a matplotlib warning.
        ax.scatter(xyz[:, 0], xyz[:, 1], xyz[:, 2], s=s)

    if lines is not None:
        for line in lines:
            # Interpolate each coordinate directly between the two endpoints.
            # This passes exactly through both of them and also handles
            # axis-aligned or zero-length segments without dividing by zero.
            x, y, z = _segment_points(line[0], line[1])
            ax.plot3D(x, y, z, 'green')
    plt.show()
    return True


# Input point cloud (placeholder path).
las_file = "xxx.las"

# Rows are points; columns are x, y, z, classification, instance id.
points = read_las_ins(las_file)

# Split the array into coordinates and the two label columns.
xyz, classes, instance = points[:, :3], points[:, 3], points[:, 4]

# One segment given by its two end points (x, y, z).
lines = [((0, 0, 0), (100, 50, 100))]

# Colour the points by instance id and overlay the segment.
plot_3d_with_line(xyz=xyz, classes=instance, lines=lines, cm='tab20')