使用Zig实现base64编码和解码
TOC
本文代码基于zig 0.15.2
Base64算法介绍#
Base64本质上是一种将二进制数据转换为文本格式的编码方式。通过将每三个字节的数据转换为四个可打印字符来实现这一点。
为啥需要它呢?因为在一些系统传输中,它只能接受文本数据,比如邮件系统(smtp)和某些网络协议(http)。通过Base64编码,我们可以确保二进制数据在这些系统中能够安全传输。我们将二进制数据转换为文本格式后,接收端再根据Base64规则将其还原为原始的二进制数据。因此,如果出于任何原因,你需要在邮件中以附件形式发送二进制文件(例如PDF或Excel文件),这些二进制文件在纳入SMTP消息前通常会被转换为base64。所以base64编码在这些邮件系统中被广泛用于将二进制数据嵌入SMTP消息。
Base64通过将每三个字节的数据转换为四个可打印字符来实现这一点。
为啥是3个字节转换为4字符?因为3字节等于24位,而Base64使用6位来表示一个字符,所以24位刚好可以被划分为4个6位的块。
因此,Base64算法的运作方式是每次将3字节转换为4个Base64字符。它会遍历输入字符串,每次处理3字节,并将其转换为Base64字符,每次迭代生成4个字符。该算法会持续迭代并生成这些“新字符”,直到处理完整个输入字符串。
Base64字符集(编码表)#
Base64字符集就是64个可打印字符的集合。也就是说生成的Base64编码字符串只能包含这些字符。
- 大写字母:A-Z (26个字符)
- 小写字母:a-z (26个字符)
- 数字:0-9 (10个字符)
- 特殊字符:+ 和 / (2个字符)
还有一个特殊的填充字符=,用于确保编码后的字符串长度是4的倍数(补齐)。
编码#
原理如图,我这里直接截个网图:
![]()
- 先将输入的字符串用二进制表示
- 将这些二进制按照6位一组进行拆分(从左到右的顺序,如果一组不足6位,则用0进行补齐。每3个字节会被拆分为4组6位,如果最终不足4组,则用填充字符
=补齐) - 将每组6位转换为对应的十进制数
- 根据Base64编码表,将这些十进制数转换为对应的字符(根据上面列出的Base64字符集顺序)
这样就得到了Base64编码后的字符串。
解码#
请记住,base64解码器本质上是在逆转base64编码器执行的操作。
原理如图:
![]()
- 解码这,输入的是Base64编码后的字符串,这里先将每个字符转换为对应的十进制数(根据Base64编码表,通过字符找到编码表对应的索引,注意要忽略填充字符
=) - 因为输入的是Base64编码后的字符,所以每个字符对应6位二进制数,此时需要将这些6位二进制数重新组合成原始的8位字节(每4个Base64字符会被转换为3个字节)
- 最后将这些字节重新组合成原始的字符串(直接打印即可,在zig中字符串本质就是字节数组,和c一样。注意忽略填充字符
=)
结构体及基础方法实现#
定义一个结构体,包含上面说的编码表,和一些后续我们可能用到的方法。
const Base64 = struct {
_table: *const [64]u8,
pub fn init() Base64 {
return Base64{
._table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
};
}
fn _char_at(self: Base64, i: usize) u8 {
return self._table[i];
}
fn _char_index(self: Base64, c: u8) u8 {
if (c == '=') {
return 64;
}
var res: u8 = 0;
for (self._table) |val| {
if (val == c) {
return res;
}
res += 1;
}
return res;
}
};
这里我们定义了一个Base64结构体,并在init方法中初始化了Base64编码表。_char_at方法用于根据索引获取对应的Base64字符。
_char_index方法用于根据Base64字符获取对应的索引。Base64字符集中的填充字符=被特殊处理,返回64。
计算编码和解码所需的缓冲区大小#
通过上面大致介绍的编码和解码原理,我们清楚,编码本质上Base64就是将3个字节转换为4个字符,解码就是逆操作。
如果我们输入的字节数大于3的话,在代码中就需要循环处理这些字节。先处理3个,然后存下来,接着处理下3个,直到处理完整个输入字节数组。解码也一样,反正就是需要找个地方存下来。
此时就涉及到分配内存了,分配内存也就意味着我们需要算出编码和解码所需的缓冲区(内存)大小。
编码所需要的缓冲区大小:先将输入字符数除以3然后向上取整,最后乘以4。向上取整是因为如果输入的字节数不是3的倍数(此时会有剩余字节),我们仍然需要为这些剩余字节分配4个字符的空间。
同时也要注意下输入字节数小于3的情况,这种情况下我们也需要分配4个字符的空间。代码如下:
fn _cal_encode_length(input: []const u8) !usize {
if (input.len < 3) {
return 4;
}
const n: usize = try std.math.divCeil(usize, input.len, 3);
return n * 4;
}
std.math.divCeil函数用于计算向上取整的结果。
解码所需要的缓冲区大小:先将输入字符数除以4然后向下取整,最后乘以3。向下取整是因为如果输入的Base64编码字符串不是4的倍数(此时会有剩余字符),我们只能处理完整的4字符块,剩余字符无法解码。
在这里编码向上取整是因为不能放过任何一个输入字节,而解码向下取整是因为不能处理不完整的4字符块。计算解码所需缓冲区大小代码如下:
fn _cal_decode_length(input: []const u8) !usize {
if (input.len < 4) {
return 3;
}
var n: usize = try std.math.divFloor(usize, input.len, 4);
n *= 3;
var i: usize = input.len - 1;
while (i > 0) : (i -= 1) {
if (input[i] == '=') {
n -= 1;
} else {
break;
}
}
return n;
}
这里也是差不多,需要考虑输入字符数小于4的情况。同时还需要处理填充字符=,每个填充字符都意味着少了1个字节,所以我们需要在计算结果中减去相应的字节数。std.math.divFloor函数用于计算向下取整的结果。
编码逻辑实现#
我们知道原理后,剩下的编码操作就只是搬砖了。本质上就是每次处理3个字节(或者说3个ascii字符,即3个8位二进制),然后转换为4个Base64字符,直到处理完输入字符。
编码的核心就是将3个8位二进制转换为4个6位二进制,然后根据Base64编码表获取对应的字符。
具体怎么转换?位运算,比如第一个字节的前6位直接作为第一个6位二进制,第一个后2位和第二个字节的前4位组合成第二个6位二进制,依次类推。
这里涉及到两个位运算,一个是右移操作>>,另一个是按位与操作&。举个例子,我们使用右移操作将第一个字节右移2位,就得到了第一个字节的前6位,然后我们将原始的第一个字节再和0b00000011按位与操作,就得到了第一个字节的后2位,接着加上第二个字节右移4位的结果,就得到了第二个6位二进制。
同时注意一些边界条件,比如输入字符数不是3的倍数时的处理。
- 当尾部的输入字符为2个时,我们需要在编码后的字符串末尾添加1个填充字符
=。 - 当尾部的输入字符为1个时,我们需要在编码后的字符串末尾添加2个填充字符
=。
编码操作涉及到内存分配,所以编码函数需要外部传入一个内存分配器。
代码如下:
// encode input bytes to base64 string
pub fn encode(
self: Base64,
allocator: std.mem.Allocator,
input: []const u8,
) ![]u8 {
if (input.len == 0) {
return "";
}
const n_output: usize = try _cal_encode_length(input);
var out = try allocator.alloc(u8, n_output);
var buf = [3]u8{ 0, 0, 0 };
var count: u8 = 0;
var iout: usize = 0;
for (input, 0..) |_, i| {
buf[count] = input[i];
count += 1;
if (count == 3) {
// 三字符窗口
out[iout] = self._char_at(buf[0] >> 2);
out[iout + 1] = self._char_at(((buf[0] & 0b00000011) << 4) + (buf[1] >> 4));
out[iout + 2] = self._char_at(((buf[1] & 0b00001111) << 2) + (buf[2] >> 6));
out[iout + 3] = self._char_at(buf[2] & 0b00111111);
iout += 4;
count = 0;
}
}
// 单字符窗口判断
if (count == 1) {
out[iout] = self._char_at(buf[0] >> 2);
out[iout + 1] = self._char_at((buf[0] & 0b00000011) << 4);
// 补齐
out[iout + 2] = '=';
out[iout + 3] = '=';
}
// 双字符窗口判断
if (count == 2) {
out[iout] = self._char_at(buf[0] >> 2);
out[iout + 1] = self._char_at(((buf[0] & 0b00000011) << 4) + (buf[1] >> 4));
out[iout + 2] = self._char_at((buf[1] & 0b00001111) << 2);
// 补齐
out[iout + 3] = '=';
}
return out;
}
每次循环处理3个字节,将编码结果保存至buf数组中,处理完3个字节后,将结果写入输出数组out中。count变量就用于记录当前处理了多少个字节。iout变量用于记录输出数组的写入位置。最后判断count的值,处理剩余的1或2个字节,并添加相应的填充字符=。
上述的一些位移操作,各位自己拿草稿纸画一下就明白了。
解码逻辑实现#
解码操作和编码操作类似,本质上就是每次处理4个Base64字符,然后转换为3个字节,直到处理完输入的Base64编码字符串。
有几个不同的点:
- 要将输入的
Base64字符转换为对应的索引值(6位二进制) - 位移操作不同,因为是逆操作。比如获取第一个字节时,需要将第一个6位二进制左移2位,然后加上第二个6位二进制右移4位的结果(此时就通过2个6位二进制字节拼成1个8位二进制字节了)。接着第二位字节就是原始的第二个6位二进制左移4位,加上第三个6位二进制右移2位的结果,依次类推。
- 一些边界条件不同,因为这次是4个Base64字符转换为3个字节。所以一些变量的判断也不同。
- 记得忽略填充字符
=,因为它们不参与解码。(因为它没有意义)
// decode base64 string to bytes
pub fn decode(
self: Base64,
allocator: std.mem.Allocator,
input: []const u8,
) ![]u8 {
if (input.len == 0) {
return "";
}
const n_output: usize = try _cal_decode_length(input);
var out = try allocator.alloc(u8, n_output);
var count: u8 = 0;
var iout: u64 = 0;
var buf = [4]u8{ 0, 0, 0, 0 };
for (0..input.len) |i| {
buf[count] = self._char_index(input[i]);
count += 1;
if (count == 4) {
out[iout] = (buf[0] << 2) + (buf[1] >> 4);
if (buf[2] != 64) {
out[iout + 1] = (buf[1] << 4) + (buf[2] >> 2);
}
if (buf[3] != 64) {
out[iout + 2] = (buf[2] << 6) + buf[3];
}
iout += 3;
count = 0;
}
}
return out;
}
buf数组上限为4,因为每次处理4个Base64字符。count变量用于记录当前处理了多少个Base64字符,当其为4时,说明我们需要将里面的元素进行解码操作了。iout变量用于记录输出数组的写入位置。
还记得前面的_char_index方法返回64代表着什么吗?就是=。需要忽略它。
Base64字符串,每4位一组,只有其最后两位可能是填充字符=,所以我们只需要在解码时判断buf[2]和buf[3]是否为64即可。
测试代码示例#
pub fn main() !void {
const b = Base64.init();
var gpa = std.heap.GeneralPurposeAllocator(.{}){};
const allocator = gpa.allocator();
// encode
const encoded = try b.encode(allocator, "Hello");
defer allocator.free(encoded);
std.debug.print("Encoded: {s}\n", .{encoded});
// decode
const decoded = try b.decode(allocator, encoded);
defer allocator.free(decoded);
std.debug.print("Decoded: {s}\n", .{decoded});
}
终端输出:
$ ./hello
Encoded: SGVsbG8=
Decoded: Hello