本文代码基于zig 0.15.2

Base64算法介绍#

Base64本质上是一种将二进制数据转换为文本格式的编码方式。通过将每三个字节的数据转换为四个可打印字符来实现这一点。

为啥需要它呢?因为在一些系统传输中,它只能接受文本数据,比如邮件系统(smtp)和某些网络协议(http)。通过Base64编码,我们可以确保二进制数据在这些系统中能够安全传输。我们将二进制数据转换为文本格式后,接收端再根据Base64规则将其还原为原始的二进制数据。因此,如果出于任何原因,你需要在邮件中以附件形式发送二进制文件(例如PDF或Excel文件),这些二进制文件在纳入SMTP消息前通常会被转换为base64。所以base64编码在这些邮件系统中被广泛用于将二进制数据嵌入SMTP消息。

Base64通过将每三个字节的数据转换为四个可打印字符来实现这一点。

为啥是3个字节转换为4字符?因为3字节等于24位,而Base64使用6位来表示一个字符,所以24位刚好可以被划分为4个6位的块。

因此,Base64算法的运作方式是每次将3字节转换为4个Base64字符。它会遍历输入字符串,每次处理3字节,并将其转换为Base64字符,每次迭代生成4个字符。该算法会持续迭代并生成这些“新字符”,直到处理完整个输入字符串。

Base64字符集(编码表)#

Base64字符集就是64个可打印字符的集合。也就是说生成的Base64编码字符串只能包含这些字符。

  • 大写字母:A-Z (26个字符)
  • 小写字母:a-z (26个字符)
  • 数字:0-9 (10个字符)
  • 特殊字符:+ 和 / (2个字符)

还有一个特殊的填充字符=,用于确保编码后的字符串长度是4的倍数(补齐)。

编码#

原理如图,我这里直接截个网图:

01

  1. 先将输入的字符串用二进制表示
  2. 将这些二进制按照6位一组进行拆分(从左到右的顺序,如果一组不足6位,则用0进行补齐。每3个字节会被拆分为4组6位,如果最终不足4组,则用填充字符=补齐)
  3. 将每组6位转换为对应的十进制数
  4. 根据Base64编码表,将这些十进制数转换为对应的字符(根据上面列出的Base64字符集顺序)

这样就得到了Base64编码后的字符串。

解码#

请记住,base64解码器本质上是在逆转base64编码器执行的操作。

原理如图:

02

  1. 解码这,输入的是Base64编码后的字符串,这里先将每个字符转换为对应的十进制数(根据Base64编码表,通过字符找到编码表对应的索引,注意要忽略填充字符=)
  2. 因为输入的是Base64编码后的字符,所以每个字符对应6位二进制数,此时需要将这些6位二进制数重新组合成原始的8位字节(每4个Base64字符会被转换为3个字节)
  3. 最后将这些字节重新组合成原始的字符串(直接打印即可,在zig中字符串本质就是字节数组,和c一样。注意忽略填充字符=)

结构体及基础方法实现#

定义一个结构体,包含上面说的编码表,和一些后续我们可能用到的方法。

const Base64 = struct {
    _table: *const [64]u8,

    pub fn init() Base64 {
        return Base64{
            ._table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
        };
    }

    fn _char_at(self: Base64, i: usize) u8 {
        return self._table[i];
    }

    fn _char_index(self: Base64, c: u8) u8 {
        if (c == '=') {
            return 64;
        }

        var res: u8 = 0;

        for (self._table) |val| {
            if (val == c) {
                return res;
            }
            res += 1;
        }

        return res;
    }
};

这里我们定义了一个Base64结构体,并在init方法中初始化了Base64编码表。_char_at方法用于根据索引获取对应的Base64字符。

_char_index方法用于根据Base64字符获取对应的索引。Base64字符集中的填充字符=被特殊处理,返回64。

计算编码和解码所需的缓冲区大小#

通过上面大致介绍的编码和解码原理,我们清楚,编码本质上Base64就是将3个字节转换为4个字符,解码就是逆操作。

如果我们输入的字节数大于3的话,在代码中就需要循环处理这些字节。先处理3个,然后存下来,接着处理下3个,直到处理完整个输入字节数组。解码也一样,反正就是需要找个地方存下来。

此时就涉及到分配内存了,分配内存也就意味着我们需要算出编码和解码所需的缓冲区(内存)大小。

编码所需要的缓冲区大小:先将输入字符数除以3然后向上取整,最后乘以4。向上取整是因为如果输入的字节数不是3的倍数(此时会有剩余字节),我们仍然需要为这些剩余字节分配4个字符的空间。

同时也要注意下输入字节数小于3的情况,这种情况下我们也需要分配4个字符的空间。代码如下:

fn _cal_encode_length(input: []const u8) !usize {
   if (input.len < 3) {
       return 4;
   }

   const n: usize = try std.math.divCeil(usize, input.len, 3);
   return n * 4;
}

std.math.divCeil函数用于计算向上取整的结果。

解码所需要的缓冲区大小:先将输入字符数除以4然后向下取整,最后乘以3。向下取整是因为如果输入的Base64编码字符串不是4的倍数(此时会有剩余字符),我们只能处理完整的4字符块,剩余字符无法解码。

在这里编码向上取整是因为不能放过任何一个输入字节,而解码向下取整是因为不能处理不完整的4字符块。计算解码所需缓冲区大小代码如下:

fn _cal_decode_length(input: []const u8) !usize {
    if (input.len < 4) {
        return 3;
    }

    var n: usize = try std.math.divFloor(usize, input.len, 4);
    n *= 3;

    var i: usize = input.len - 1;
    while (i > 0) : (i -= 1) {
        if (input[i] == '=') {
            n -= 1;
        } else {
            break;
        }
    }

    return n;
}

这里也是差不多,需要考虑输入字符数小于4的情况。同时还需要处理填充字符=,每个填充字符都意味着少了1个字节,所以我们需要在计算结果中减去相应的字节数。std.math.divFloor函数用于计算向下取整的结果。

编码逻辑实现#

我们知道原理后,剩下的编码操作就只是搬砖了。本质上就是每次处理3个字节(或者说3个ascii字符,即3个8位二进制),然后转换为4个Base64字符,直到处理完输入字符。

编码的核心就是将3个8位二进制转换为4个6位二进制,然后根据Base64编码表获取对应的字符。

具体怎么转换?位运算,比如第一个字节的前6位直接作为第一个6位二进制,第一个后2位和第二个字节的前4位组合成第二个6位二进制,依次类推。

这里涉及到两个位运算,一个是右移操作>>,另一个是按位与操作&。举个例子,我们使用右移操作将第一个字节右移2位,就得到了第一个字节的前6位,然后我们将原始的第一个字节再和0b00000011按位与操作,就得到了第一个字节的后2位,接着加上第二个字节右移4位的结果,就得到了第二个6位二进制。

同时注意一些边界条件,比如输入字符数不是3的倍数时的处理。

  1. 当尾部的输入字符为2个时,我们需要在编码后的字符串末尾添加1个填充字符=
  2. 当尾部的输入字符为1个时,我们需要在编码后的字符串末尾添加2个填充字符=

编码操作涉及到内存分配,所以编码函数需要外部传入一个内存分配器。

代码如下:

    // encode input bytes to base64 string
    pub fn encode(
        self: Base64,
        allocator: std.mem.Allocator,
        input: []const u8,
    ) ![]u8 {
        if (input.len == 0) {
            return "";
        }

        const n_output: usize = try _cal_encode_length(input);
        var out = try allocator.alloc(u8, n_output);
        var buf = [3]u8{ 0, 0, 0 };
        var count: u8 = 0;
        var iout: usize = 0;

        for (input, 0..) |_, i| {
            buf[count] = input[i];
            count += 1;

            if (count == 3) {
                // 三字符窗口
                out[iout] = self._char_at(buf[0] >> 2);
                out[iout + 1] = self._char_at(((buf[0] & 0b00000011) << 4) + (buf[1] >> 4));
                out[iout + 2] = self._char_at(((buf[1] & 0b00001111) << 2) + (buf[2] >> 6));
                out[iout + 3] = self._char_at(buf[2] & 0b00111111);

                iout += 4;
                count = 0;
            }
        }

        // 单字符窗口判断
        if (count == 1) {
            out[iout] = self._char_at(buf[0] >> 2);
            out[iout + 1] = self._char_at((buf[0] & 0b00000011) << 4);
            // 补齐
            out[iout + 2] = '=';
            out[iout + 3] = '=';
        }

        // 双字符窗口判断
        if (count == 2) {
            out[iout] = self._char_at(buf[0] >> 2);
            out[iout + 1] = self._char_at(((buf[0] & 0b00000011) << 4) + (buf[1] >> 4));
            out[iout + 2] = self._char_at((buf[1] & 0b00001111) << 2);
            // 补齐
            out[iout + 3] = '=';
        }

        return out;
    }

每次循环处理3个字节,将编码结果保存至buf数组中,处理完3个字节后,将结果写入输出数组out中。count变量就用于记录当前处理了多少个字节。iout变量用于记录输出数组的写入位置。最后判断count的值,处理剩余的1或2个字节,并添加相应的填充字符=

上述的一些位移操作,各位自己拿草稿纸画一下就明白了。

解码逻辑实现#

解码操作和编码操作类似,本质上就是每次处理4个Base64字符,然后转换为3个字节,直到处理完输入的Base64编码字符串。

有几个不同的点:

  1. 要将输入的Base64字符转换为对应的索引值(6位二进制)
  2. 位移操作不同,因为是逆操作。比如获取第一个字节时,需要将第一个6位二进制左移2位,然后加上第二个6位二进制右移4位的结果(此时就通过2个6位二进制字节拼成1个8位二进制字节了)。接着第二位字节就是原始的第二个6位二进制左移4位,加上第三个6位二进制右移2位的结果,依次类推。
  3. 一些边界条件不同,因为这次是4个Base64字符转换为3个字节。所以一些变量的判断也不同。
  4. 记得忽略填充字符=,因为它们不参与解码。(因为它没有意义)
    // decode base64 string to bytes
    pub fn decode(
        self: Base64,
        allocator: std.mem.Allocator,
        input: []const u8,
    ) ![]u8 {
        if (input.len == 0) {
            return "";
        }

        const n_output: usize = try _cal_decode_length(input);
        var out = try allocator.alloc(u8, n_output);
        var count: u8 = 0;
        var iout: u64 = 0;
        var buf = [4]u8{ 0, 0, 0, 0 };

        for (0..input.len) |i| {
            buf[count] = self._char_index(input[i]);
            count += 1;

            if (count == 4) {
                out[iout] = (buf[0] << 2) + (buf[1] >> 4);

                if (buf[2] != 64) {
                    out[iout + 1] = (buf[1] << 4) + (buf[2] >> 2);
                }

                if (buf[3] != 64) {
                    out[iout + 2] = (buf[2] << 6) + buf[3];
                }

                iout += 3;
                count = 0;
            }
        }

        return out;
    }

buf数组上限为4,因为每次处理4个Base64字符。count变量用于记录当前处理了多少个Base64字符,当其为4时,说明我们需要将里面的元素进行解码操作了。iout变量用于记录输出数组的写入位置。

还记得前面的_char_index方法返回64代表着什么吗?就是=。需要忽略它。

Base64字符串,每4位一组,只有其最后两位可能是填充字符=,所以我们只需要在解码时判断buf[2]buf[3]是否为64即可。

测试代码示例#

pub fn main() !void {
    const b = Base64.init();
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();

    // encode
    const encoded = try b.encode(allocator, "Hello");
    defer allocator.free(encoded);

    std.debug.print("Encoded: {s}\n", .{encoded});

    // decode
    const decoded = try b.decode(allocator, encoded);
    defer allocator.free(decoded);

    std.debug.print("Decoded: {s}\n", .{decoded});
}

终端输出:

$ ./hello
Encoded: SGVsbG8=
Decoded: Hello