Technical Report on Secure Truncation with Applications to LLM Quantization

Zyllion

Abstract

Over the last decade, machine learning has gained significant attention and is now used extensively in practice. Quantization, especially for large language models, has been a useful technique to increase speed and reduce memory footprint, making these techniques more accessible and deployable in resource-constrained environments. The popularity of machine learning and the need for data privacy have led to the development of privacy-preserving machine learning systems, many of which use secure multi-party computation. The arithmetic required for machine learning in secure multi-party computation typically relies on fixed-point arithmetic, as floating-point operations are expensive in practice. Fixed-point arithmetic employs truncation for secure multiplication and other non-linear operations. We propose a new framework to truncate securely with perfect security and low memory footprint, which is especially efficient when the bit-width is low. For 8-bit fields, this reduces communication complexity by 61% and memory usage by 83%.

1 Introduction

Machine learning, and in particular deep learning, has emerged as a transformative technology with applications spanning various domains, from healthcare and finance to autonomous vehicles and natural language processing such as large language models (LLMs) [27, 10]. The ability to analyze vast datasets and make data-driven predictions has revolutionized industries and enhanced decision-making processes [18, 21]. However, as the adoption of machine learning continues to grow, so do the concerns related to data privacy and security [22, 26]. The need to strike a balance between reaping the benefits of machine learning and safeguarding individual and organizational privacy has given rise to the field of privacy-preserving machine learning [24, 17].
Privacy-preserving machine learning using secure multi-party computation (MPC) has gained significant attention in recent years due to the property that sensitive information is not disclosed during the computation. MPC techniques include garbled circuits [2, 30, 31], function secret sharing [34, 16], fully homomorphic encryption based MPC [25, 35], and linear secret sharing schemes (LSSS) [19, 11, 5]. Most of these methods are designed for two-party computation. Function secret sharing, for instance, has great applications and is very efficient in the two-party setting; for example, Sigma [16] makes LLM inference fast and practical. If we want to scale beyond two parties, however, some of these techniques are no longer efficient. Quantization is a well-known technique in deep learning for decreasing memory footprint as well as increasing computational performance [36, 14, 9]. The basic idea behind quantization is to represent the weights and activations of a neural network using a lower number of bits, such as 4- or 8-bit integers, rather than the standard 32-bit or 64-bit floating-point precision. This can lead to significant reductions in both model size and memory requirements, making it more feasible to deploy models on resource-constrained devices, such as mobile phones or edge devices. Secure computation gives rise to significant overhead, which is also reduced by quantization [6]. Truncation is a fundamental building block for secure integer and fixed-point arithmetic [3, 4, 12] that is needed in privacy-preserving machine learning. Typically, after a multiplication or an inner product of fixed-point numbers, one must employ truncation to reduce the precision of the result. Therefore, truncation is a basic building block for efficient quantization. Typically, general-purpose truncation protocols for fixed-point arithmetic, as discussed in [3], need a representation with an augmented number of bits to ensure statistical privacy.
This situation undermines the idea of leveraging the benefits of quantized models, given that the size of the statistical parameter often exceeds the quantized size. However, there exist techniques for truncation in rings that are compatible with a quantized approach [6]. In this work, we introduce a general multi-party, perfectly secure truncation protocol designed for small fields. As a consequence, we reduce the representation cost observed in previous works [3] incurred by the statistical security parameter.

2 Transformer Architecture

The transformer architecture is a type of neural network architecture [33] built on attention mechanisms, which allow the model to focus on different parts of the input sequence when generating an output. It has become a foundational model for various natural language processing (NLP) tasks and has been widely adopted in machine translation, text generation, and other sequence-to-sequence tasks. The transformer architecture is a fundamental building block for LLMs. Several parts of an LLM architecture require the use of truncation; we now briefly describe them. Transformers process sequences of tokens, where each token is typically a word or subword, and represent them as one-dimensional vectors of size d. These vectors are often referred to as embeddings. Next, transformers rely on self-attention, a mechanism that allows the model to weigh the importance of different parts of an input sequence when making predictions. It takes a query matrix Q and a set of key-value matrix pairs K, V and produces an output as follows:

Attention(Q, K, V) = softmax(QK^T / √d) · V

To compute this matrix multiplication, we run many inner products, each of which must be followed by a truncation. For softmax, one can compute the exponential by approximating the limit lim_{n→∞} (1 + x/n)^n [19], followed by division using the Newton-Raphson method [20].
Similarly, the inverse square root can also be calculated using the Newton-Raphson method [20]. All of these require multiplications followed by truncation. Next in the transformer architecture comes the feed-forward neural network, computed as an inner product followed by an activation function. ReLU or GeLU activations are commonly used; ReLU can be implemented with only a comparison, while GeLU can be approximated by a ReLU plus a Taylor expansion of the difference around zero [16]. Finally, layer normalization uses an inverse square root. These elements, too, require multiplications followed by truncation.

3 Data Representation

We follow the approach of previous works [3, 6] and consider secure computation with binary values encoded as 0 and 1, signed integers, and fixed-point rationals. Signed integer values are defined by Z⟨k⟩ = {x̄ ∈ Z | −2^(k−1) ≤ x̄ < 2^(k−1)}, and are encoded via encode: Z⟨k⟩ → Z_q as encode(x̄) = x̄ mod q, where q > 2^k. Fixed-point rational numbers are encoded as integers x̄ = x̃·2^f ∈ Z⟨k⟩, defined by Q⟨k,f⟩ = {x̃ ∈ Q | x̃ = x̄ · 2^(−f), x̄ ∈ Z⟨k⟩}. Fixed-point multiplication in MPC requires q > 2^(2k) to prevent overflows. Quantization from a floating-point number x to a fixed-point number x̃ is achieved by picking a precision parameter f, multiplying the number by 2^f, and rounding down to the nearest integer:

x̃ = ⌊x · 2^f⌋

The fixed-point number can be converted back simply via a division by 2^f. Note that accuracy may be lost during quantization, so the converted-back number might lose precision:

x ≈ x̃ / 2^f

This is akin to uniform quantization in machine learning, given by the formula Q(x) = int(x/S) + Z [14], where Q is the quantization function, x is the real-valued input, S is the real-valued scaling factor, int maps the input to the nearest integer value, and Z is the integer zero point.
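As a plaintext illustration of this encoding (the function names and the concrete parameters k = 16, f = 8, q = 65537 are our own choices for the sketch):

```python
import math

def encode_fixed(x, f=8, k=16, q=65537):
    # Quantize a float to x_bar = floor(x * 2^f) in Z<k>, then embed in Z_q
    # as x_bar mod q; requires q > 2^k.
    x_bar = math.floor(x * (1 << f))
    assert -(1 << (k - 1)) <= x_bar < (1 << (k - 1)), "out of Z<k> range"
    return x_bar % q

def decode_fixed(e, f=8, q=65537):
    # Undo the field embedding (elements >= q/2 represent negatives),
    # then divide by 2^f; the floor in encode_fixed may lose precision.
    x_bar = e if e < q // 2 else e - q
    return x_bar / (1 << f)
```

For example, 3.14159 round-trips to ⌊3.14159 · 2^8⌋ / 2^8 = 3.140625, showing the precision loss, while exactly representable values such as −1.5 round-trip unchanged.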
4 Trust Model

We set up a truncation protocol that is secure in the same setting as the underlying LSSS-based MPC protocol. Examples include:

• BGW/GRR with t < n/2 passive corruptions [1, 13]
• ATLAS with t < n/2 active corruptions [15]

For performance calculations, we work in the passive-adversary setup, where there are t < n/2 corrupt parties. Although our protocol works for any linear secret sharing scheme, we assume all values are secret-shared using Shamir's secret sharing scheme [32] over a prime field F_q. Further, perfect security against an active adversary that can corrupt t < n/3 parties can be achieved using verifiable secret sharing (VSS) [13] and maliciously secure protocols such as BGW [29].

5 Truncation in Small Field

Various forms of overhead occur when we use MPC protocols. Specifically, when performing multiplication, one needs to double the field size to prevent overflows; for example, the product of two 8-bit integers may be a 16-bit integer. Additionally, when implementing truncation, extra statistical security bits λ (typically around 40) are crucial for the single-round probabilistic truncation protocol of [3] to work securely. Further, we either need an even larger modulus for pseudo-random integer generation [7], or we must generate the random integer from random bits, which requires a random number along with a secure multiplication [8]. Working with 128 + λ bits (128 to prevent multiplication overflow and 40 for statistical security) for 64-bit arithmetic does not seem an excessive overhead. For smaller sizes, however, it defeats the purpose of quantization: for 8-bit arithmetic, it turns 16-bit fields into 56-bit fields, more than 3 times as many bits! We can, however, get rid of the statistical bits in exchange for a comparison protocol computing bitwise less-than [23, 3, 8], as in π_TRUNC-PR below. This protocol works in a finite field F_q, where l = ⌈log_2 q⌉ > 2k, for secret input x̃ ∈ Q⟨k,f⟩.
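The overhead arithmetic above can be made explicit. The helper below is our own bookkeeping sketch, with λ = 40 as the statistical parameter from [3]:

```python
def stat_trunc_field_bits(k, lam=40):
    # Bits required by the statistical approach: 2k to absorb one
    # multiplication without overflow, plus lambda statistical-security bits.
    return 2 * k + lam

# 64-bit values: 2*64 + 40 = 168 bits, a modest relative overhead.
# 8-bit values:  2*8  + 40 = 56 bits, 3.5x the 16 bits a product needs.
```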
It takes the input in secret-shared signed-integer form [x] and a clear value m, returning a sharing of ⌊x/2^m⌋ + u, where u ∈ {0,1} is a random term [8, 28]. At a high level, we first generate a random value r that is bitwise shared as ([r_0], ..., [r_{l−1}]) with r_i ∈ {0,1}. We then reveal the secret masked by this random value. Using a bitwise comparison, we check whether an overflow occurred, and we use this to compute the remainder of the revealed value modulo 2^m. Finally, we combine this with the first m bits of the bitwise random number to securely compute the truncated value.

Quantized Probabilistic Truncation, π_TRUNC-PR: [x/2^m] ← π_TRUNC-PR([x], m)

  1. ([r_0], ..., [r_{l−1}]) ← F_RAN-BITWISE()
     (a) [r′] ← Σ_{i=0}^{m−1} 2^i · [r_i]
     (b) [r] ← Σ_{i=0}^{l−1} 2^i · [r_i]

  2. c ← F_REVEAL(2^(k−1) + [x] + [r])

  3. [c < r] ← F_BIT-LESS-THAN(c, ([r_0], ..., [r_{l−1}]))

  4. [c′] ← (1 − [c < r]) · (c mod 2^m) + [c < r] · ((c + q) mod 2^m)

  5. [w] ← [x] + [r′] − [c′]

  6. [x/2^m] ← [w] · (2^m)^(−1)

  7. OUTPUT [x/2^m]

Correctness: Note that the secret x̃ ∈ Q⟨k,f⟩ is first translated into x̄ ∈ Z⟨k⟩ as x̄ = x̃·2^f, which is encoded as x = x̄ mod q. Let b = (2^(k−1) + x̄) mod q, hence 0 ≤ b < 2^k. Setting b′ = b mod 2^m and x′ = x̄ mod 2^m, we have b′ = x′ for any 0 < m < k. In the protocol, we first generate a bitwise-shared random number in the field during preprocessing. We denote this vector of sharings of bits by ([r_0], ..., [r_{l−1}]), and we set r ← Σ_{i=0}^{l−1} 2^i · r_i, where r_i is the i-th bit of r and l = ⌈log_2 q⌉ > 2k. We set r′ = r mod 2^m. We add r to the secret, as well as the 2^(k−1) term to account for negative numbers, revealing the result c = (b + r) mod q. Now, if c ≥ r, then c = b + r; else c = b + r − q; thus c = b + r − (c < r)·q. Hence we compare the clear value c to the bitwise-shared random number r to check for overflows. Then we take the modulo 2^m in the clear as follows: c′ = c mod 2^m if c ≥ r, and c′ = (c + q) mod 2^m if c < r; or, in a single equation, c′ = (1 − (c < r))·(c mod 2^m) + (c < r)·((c + q) mod 2^m). If c ≥ r, then c′ = (b + r) mod 2^m = b′ + r′ − 2^m·u, where u ∈ {0,1}. Else if c < r, there was an overflow, which means b + r ≥ q, hence b + r = c + q. Then, as before, c′ = (c + q) mod 2^m = b′ + r′ − 2^m·u, where u ∈ {0,1}. We get x′ = b′ = c′ − r′ + 2^m·u, so that w = x + r′ − c′ = x − x′ + 2^m·u. Since x − x′ = 2^m·⌊x/2^m⌋, we obtain w · (2^m)^(−1) = (2^m·⌊x/2^m⌋ + 2^m·u) · (2^m)^(−1) = ⌊x/2^m⌋ + u. This is the probabilistic truncation result we wanted. Note that Pr(u = 1) = Pr(x′ + r′ ≥ 2^m), which gives the rounding property.

Security: Information can leak only in steps where a shared value is reconstructed. These values are of the form y = x + r, where x is the secret and r is uniformly random in the field. Since r is uniformly random, x + r is a one-time pad of the secret x and is therefore perfectly secure. Since the sub-protocols F_RAN-BITWISE [28] and F_BIT-LESS-THAN [23] are perfectly secure, we conclude that π_TRUNC-PR is perfectly secure.
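The correctness argument can be checked with an insecure plaintext walk-through of π_TRUNC-PR, in which every "shared" value is handled in the clear. The function name and parameter defaults are ours; a real execution would use secret sharings and the F_RAN-BITWISE and F_BIT-LESS-THAN functionalities:

```python
import random

def trunc_pr_plain(x_bar, m, k=12, q=2**31 - 1):
    # Plaintext simulation of pi_TRUNC-PR; needs ceil(log2 q) > 2k, 0 < m < k.
    x = x_bar % q                       # encoded secret, x = x_bar mod q
    r = random.randrange(q)             # stands in for the bitwise-shared r
    r_prime = r % (1 << m)              # r' = r mod 2^m
    c = ((1 << (k - 1)) + x + r) % q    # revealed masked value
    overflow = 1 if c < r else 0        # via F_BIT-LESS-THAN on the bits of r
    c_prime = (c + overflow * q) % (1 << m)
    w = (x + r_prime - c_prime) % q
    # Multiply by the field inverse of 2^m; the result encodes
    # floor(x_bar / 2^m) + u for a random u in {0, 1}.
    return (w * pow(1 << m, -1, q)) % q
```

Decoding the output as a signed integer yields either ⌊x̄/2^m⌋ or ⌊x̄/2^m⌋ + 1, the latter occurring when the discarded low bits plus r′ overflow 2^m, matching the probabilistic rounding property.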
Performance: Let us assume that the random bits are produced in the preprocessing phase. Regarding the online round complexity, we have log_2 log_2 q − 1 rounds for the comparison protocol [23] plus one round for the reveal operation, which results in a total of log_2 log_2 q online rounds. Regarding the total communication complexity, we need a bitwise-shared random number plus l/2 multiplications offline, and l multiplications online for the comparison protocol plus a reveal operation, resulting in a total communication complexity of 2.5l + 1 field elements sent and received between any two nodes. We assume GRR-style multiplication of Shamir shares, where each node reshares its share and sends a share to every other party [13]. In contrast, the probabilistic truncation protocol with statistical security [3] needs l random bits plus a reveal operation in the online phase. This is summarized in Table 1. It is interesting to see that our truncation protocol is better communication-wise than [3] only for small elements. For 32 bits, the statistical probabilistic truncation already requires less bandwidth; specifically, the inflection point is around 27 bits for λ = 40. This motivates the use of quantization to operate below this inflection point. Notice that, regardless of the number of bits, our memory requirements are lower than those of [3].

                           Ours (Quantized)   Statistical [3]    Ratio
Field Elements Sent        2.5l + 1           l + 1              (2.5l+1)/(l+1)
Bits Sent                  2.5l² + l          l² + (λ+1)l + λ    (2.5l²+l)/(l²+(λ+1)l+λ)
Memory per Element (bits)  l                  l + λ              l/(l+λ)
Bits Sent, l = 8           168                432                0.389
Bits Memory, l = 8         8                  48                 0.167
Bits Sent, l = 16          656                952                0.689
Bits Memory, l = 16        16                 56                 0.286
Bits Sent, l = 32          2592               2376               1.090
Bits Memory, l = 32        32                 72                 0.444

Table 1: Comparison of communication and memory complexity for an l-bit prime. Communication is between any two nodes. We use λ = 40 bits for the calculations.
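The numeric entries of Table 1 and the inflection point follow directly from the two cost formulas; the helper names below are our own:

```python
LAM = 40  # statistical security parameter lambda

def bits_ours(l):
    # 2.5*l + 1 field elements of l bits each: 2.5*l^2 + l bits.
    return (2.5 * l + 1) * l

def bits_stat(l):
    # (l + 1) elements of (l + lambda) bits: l^2 + (lambda+1)*l + lambda bits.
    return (l + 1) * (l + LAM)

# Crossover: solving 2.5*l^2 + l = l^2 + 41*l + 40 gives l ~ 27.6, so the
# statistical protocol becomes the cheaper one from l = 28 upward.
crossover = next(l for l in range(2, 64) if bits_ours(l) > bits_stat(l))
```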
6 Conclusion

We have presented a new MPC truncation protocol for linear secret sharing schemes that works well in small fields and is therefore useful for quantization. This has implications in machine learning applications, specifically in LLM quantization, where model weights may be stored and computed in smaller fields. By working in a small field without a statistical security parameter, we can load larger models into memory using quantization.

References

[1] Gilad Asharov and Yehuda Lindell. "A full proof of the BGW protocol for perfectly secure multiparty computation". In: Journal of Cryptology 30.1 (2017), pp. 58–151.
