a go dns packet parser

refractor code to be more standard

+1385 -546
+30 -31
domain_name.go
··· 2 2 3 3 import ( 4 4 "encoding/binary" 5 + "fmt" 5 6 "strings" 6 7 ) 7 8 8 - // decode_domain decodes a domain name from a buffer starting at offset. 9 + // decodeDomain decodes a domain name from a buffer starting at offset. 9 10 // It returns the domain name along with the offset and error. 10 - func decode_domain(buf []byte, offset int) (string, int, error) { 11 + func decodeDomain(buf []byte, offset int) (string, int, error) { 11 12 var builder strings.Builder 12 13 firstLabel := true 13 14 14 - seen_offsets := make(map[int]struct{}) 15 + seenOffsets := make(map[int]struct{}) 15 16 finalOffsetAfterJump := -1 16 17 17 18 currentOffset := offset 18 19 19 20 for { 20 - if _, found := seen_offsets[currentOffset]; found { 21 + if _, found := seenOffsets[currentOffset]; found { 21 22 return "", len(buf), &DomainCompressionError{} 22 23 } 23 - seen_offsets[currentOffset] = struct{}{} 24 + seenOffsets[currentOffset] = struct{}{} 24 25 25 26 length, nextOffsetAfterLen, err := getU8(buf, currentOffset) 26 27 if err != nil { 27 - return "", len(buf), err 28 + return "", len(buf), fmt.Errorf("failed to read domain label length: %w", err) 28 29 } 29 30 30 31 if length == 0 { ··· 35 36 if (length & 0xC0) == 0xC0 { 36 37 sec, nextOffsetAfterPtr, err := getU8(buf, nextOffsetAfterLen) 37 38 if err != nil { 38 - return "", len(buf), err 39 + return "", len(buf), fmt.Errorf("failed to read domain compression pointer offset byte: %w", err) 39 40 } 40 41 41 42 jumpTargetOffset := int(length&0x3F)<<8 | int(sec) ··· 43 44 if jumpTargetOffset >= len(buf) { 44 45 return "", len(buf), &BufferOverflowError{Length: len(buf), Offset: jumpTargetOffset} 45 46 } 46 - if _, found := seen_offsets[jumpTargetOffset]; found { 47 + if _, found := seenOffsets[jumpTargetOffset]; found { 47 48 return "", len(buf), &DomainCompressionError{} 48 49 } 49 50 ··· 61 62 62 63 labelBytes, nextOffsetAfterLabel, err := getSlice(buf, nextOffsetAfterLen, int(length)) 63 64 if err != nil { 64 - return "", len(buf), err 65 + return "", len(buf), fmt.Errorf("failed to read domain label data: %w", err) 65 66 } 66 67 67 68 if !firstLabel { ··· 82 83 return builder.String(), finalReadOffset, nil 83 84 } 84 85 85 - // encode_domain returns the bytes of the input bytes appened with the encoded domain name. 86 - func encode_domain(bytes []byte, domain_name string, offsets *map[string]uint16) []byte { 87 - if domain_name == "." || domain_name == "" { 88 - return append(bytes, 0) 86 + // encodeDomain returns the bytes of the input bytes appened with the encoded domain name. 87 + func encodeDomain(bytes []byte, domainName string, offsets *map[string]uint16) ([]byte, error) { 88 + if domainName == "." || domainName == "" { 89 + return append(bytes, 0), nil 89 90 } 90 91 91 - clean_domain := strings.TrimSuffix(domain_name, ".") 92 - if clean_domain == "" { 93 - return append(bytes, 0) 92 + cleanDomain := strings.TrimSuffix(domainName, ".") 93 + if cleanDomain == "" { 94 + return append(bytes, 0), nil 94 95 } 95 96 96 97 start := 0 97 - for start < len(clean_domain) { 98 - suffix := clean_domain[start:] 98 + for start < len(cleanDomain) { 99 + suffix := cleanDomain[start:] 99 100 100 101 if offset, found := (*offsets)[suffix]; found { 101 - if offset > 0x3FFF { 102 - end := strings.IndexByte(suffix, '.') 103 - if end == -1 { 104 - end = len(suffix) 105 - } 106 - } else { 107 - pointer := 0xC000 | offset 108 - return binary.BigEndian.AppendUint16(bytes, pointer) 109 - } 102 + pointer := 0xC000 | offset 103 + bytes = binary.BigEndian.AppendUint16(bytes, pointer) 104 + return bytes, nil 110 105 } 111 106 112 107 currentPos := uint16(len(bytes)) ··· 116 111 117 112 end := strings.IndexByte(suffix, '.') 118 113 var label string 119 - nextStart := len(clean_domain) 114 + nextStart := len(cleanDomain) 120 115 121 116 if end == -1 { 122 117 label = suffix ··· 130 125 labelBytes := []byte(label) 131 126 132 127 if len(labelBytes) > 63 { 133 - // XXX: maybe should return an error 134 - labelBytes = labelBytes[:63] 128 + return nil, &InvalidLabelError{Length: int(len(labelBytes))} 129 + } 130 + 131 + if len(labelBytes) == 0 && start < len(cleanDomain) { 132 + return nil, &InvalidLabelError{Length: 0} 135 133 } 136 134 137 135 bytes = append(bytes, byte(len(labelBytes))) 138 136 bytes = append(bytes, labelBytes...) 139 137 } 140 138 141 - return append(bytes, 0) 139 + bytes = append(bytes, 0) 140 + return bytes, nil 142 141 }
+143 -22
domain_test.go
··· 1 1 package magna 2 2 3 3 import ( 4 + "errors" 4 5 "testing" 5 6 6 7 "github.com/stretchr/testify/assert" ··· 9 10 func BenchmarkDecodeDomainSimple(b *testing.B) { 10 11 input := []byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0} 11 12 for i := 0; i < b.N; i++ { 12 - _, _, _ = decode_domain(input, 0) 13 + _, _, _ = decodeDomain(input, 0) 13 14 } 14 15 } 15 16 ··· 21 22 offset := 13 22 23 b.ResetTimer() 23 24 for i := 0; i < b.N; i++ { 24 - _, _, _ = decode_domain(input, offset) 25 + _, _, _ = decodeDomain(input, offset) 25 26 } 26 27 } 27 28 ··· 31 32 out := make([]byte, 0, 64) 32 33 b.ResetTimer() 33 34 for i := 0; i < b.N; i++ { 34 - _ = encode_domain(out[:0], domain, &offsets) 35 + _, _ = encodeDomain(out[:0], domain, &offsets) 35 36 for k := range offsets { 36 37 delete(offsets, k) 37 38 } ··· 45 46 out := make([]byte, 0, 128) 46 47 b.ResetTimer() 47 48 for i := 0; i < b.N; i++ { 48 - tempOut := encode_domain(out[:0], domain1, &offsets) 49 - _ = encode_domain(tempOut, domain2, &offsets) 49 + tempOut, _ := encodeDomain(out[:0], domain1, &offsets) 50 + _, _ = encodeDomain(tempOut, domain2, &offsets) 50 51 for k := range offsets { 51 52 delete(offsets, k) 52 53 } ··· 61 62 expectedDomain string 62 63 expectedOffset int 63 64 expectedError error 65 + errorCheck func(t *testing.T, err error) 64 66 }{ 65 67 { 66 68 name: "Simple domain", ··· 83 85 expectedDomain: "", 84 86 expectedOffset: 2, 85 87 expectedError: &InvalidLabelError{Length: 64}, 88 + errorCheck: func(t *testing.T, err error) { 89 + var target *InvalidLabelError 90 + assert.True(t, errors.As(err, &target)) 91 + assert.Equal(t, 64, target.Length) 92 + }, 86 93 }, 87 94 { 88 95 name: "Compression loop", ··· 90 97 expectedDomain: "", 91 98 expectedOffset: 4, 92 99 expectedError: &DomainCompressionError{}, 100 + errorCheck: func(t *testing.T, err error) { 101 + assert.IsType(t, &DomainCompressionError{}, err) 102 + }, 93 103 }, 94 104 { 95 105 name: "Truncated input", ··· 97 107 expectedDomain: "", 98 108 expectedOffset: 3, 99 109 expectedError: &BufferOverflowError{Length: 3, Offset: 4}, 110 + errorCheck: func(t *testing.T, err error) { 111 + var target *BufferOverflowError 112 + assert.True(t, errors.As(err, &target), "Expected BufferOverflowError") 113 + if target != nil { 114 + assert.Equal(t, 3, target.Length) 115 + assert.Equal(t, 1+3, target.Offset) 116 + } 117 + assert.Contains(t, err.Error(), "failed to read domain label data") 118 + }, 100 119 }, 101 120 } 102 121 103 122 for _, tt := range tests { 104 123 t.Run(tt.name, func(t *testing.T) { 105 - domain, offset, err := decode_domain(tt.input, tt.offset) 124 + domain, offset, err := decodeDomain(tt.input, tt.offset) 125 + 126 + t.Logf("Test: %s, Input: %x, OffsetIn: %d => Domain: '%s', OffsetOut: %d, Err: %v", tt.name, tt.input, tt.offset, domain, offset, err) 127 + 128 + if tt.expectedError != nil { 129 + assert.Error(t, err, "Expected an error but got nil") 130 + if tt.errorCheck != nil { 131 + tt.errorCheck(t, err) 132 + } else { 133 + assert.IsType(t, tt.expectedError, err, "Error type mismatch") 134 + } 135 + } else { 136 + assert.NoError(t, err, "Expected no error but got one") 137 + } 106 138 107 - t.Log(tt.name) 108 - assert.Equal(t, tt.expectedError, err) 109 - assert.Equal(t, tt.expectedDomain, domain) 110 - assert.Equal(t, tt.expectedOffset, offset) 139 + assert.Equal(t, tt.expectedDomain, domain, "Domain mismatch") 140 + if tt.expectedError == nil { 141 + assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch") 142 + } 111 143 }) 112 144 } 113 145 } 114 146 115 147 func TestEncodeDomain(t *testing.T) { 116 148 tests := []struct { 117 - name string 118 - input string 119 - offsets map[string]uint16 120 - expected []byte 121 - newOffsets map[string]uint16 149 + name string 150 + input string 151 + initialBuf []byte 152 + offsets map[string]uint16 153 + expected []byte 154 + expectedErr error 155 + newOffsets map[string]uint16 122 156 }{ 123 157 { 124 158 name: "Simple domain", 125 159 input: "example.com", 160 + initialBuf: []byte{}, 126 161 offsets: make(map[string]uint16), 127 162 expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}, 128 163 newOffsets: map[string]uint16{"example.com": 0, "com": 8}, 129 164 }, 130 165 { 131 - name: "Domain with existing offset", 166 + name: "Domain with existing offset for compression", 132 167 input: "test.example.com", 168 + initialBuf: []byte{}, 133 169 offsets: map[string]uint16{"example.com": 10}, 134 170 expected: []byte{4, 't', 'e', 's', 't', 0xC0, 0x0A}, 135 171 newOffsets: map[string]uint16{"test.example.com": 0, "example.com": 10}, ··· 137 173 { 138 174 name: "Multiple subdomains", 139 175 input: "a.b.c.d", 176 + initialBuf: []byte{}, 140 177 offsets: make(map[string]uint16), 141 178 expected: []byte{1, 'a', 1, 'b', 1, 'c', 1, 'd', 0}, 142 179 newOffsets: map[string]uint16{"a.b.c.d": 0, "b.c.d": 2, "c.d": 4, "d": 6}, 143 180 }, 181 + { 182 + name: "Root domain", 183 + input: ".", 184 + initialBuf: []byte{}, 185 + offsets: make(map[string]uint16), 186 + expected: []byte{0}, 187 + newOffsets: map[string]uint16{}, 188 + }, 189 + { 190 + name: "Empty domain", 191 + input: "", 192 + initialBuf: []byte{}, 193 + offsets: make(map[string]uint16), 194 + expected: []byte{0}, 195 + newOffsets: map[string]uint16{}, 196 + }, 197 + { 198 + name: "Label too long", 199 + input: "labeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolonglabeltoolong.com", 200 + initialBuf: []byte{}, 201 + offsets: make(map[string]uint16), 202 + expected: nil, 203 + expectedErr: &InvalidLabelError{Length: 72}, 204 + newOffsets: map[string]uint16{}, 205 + }, 206 + { 207 + name: "Empty label inside domain", 208 + input: "example..com", 209 + initialBuf: []byte{}, 210 + offsets: make(map[string]uint16), 211 + expected: nil, 212 + expectedErr: &InvalidLabelError{Length: 0}, 213 + newOffsets: map[string]uint16{}, 214 + }, 215 + { 216 + name: "Append to existing buffer", 217 + input: "example.com", 218 + initialBuf: []byte{0xAA, 0xBB}, 219 + offsets: make(map[string]uint16), 220 + expected: []byte{0xAA, 0xBB, 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}, 221 + newOffsets: map[string]uint16{"example.com": 2, "com": 10}, 222 + }, 144 223 } 145 224 146 225 for _, tt := range tests { 147 226 t.Run(tt.name, func(t *testing.T) { 148 - result := encode_domain([]byte{}, tt.input, &tt.offsets) 149 - assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output") 150 - assert.Equal(t, tt.newOffsets, tt.offsets, "Offsets map does not match expected state") 227 + currentOffsets := make(map[string]uint16) 228 + for k, v := range tt.offsets { 229 + currentOffsets[k] = v 230 + } 231 + 232 + result, err := encodeDomain(tt.initialBuf, tt.input, &currentOffsets) 233 + 234 + if tt.expectedErr != nil { 235 + assert.Error(t, err, "Expected an error but got nil") 236 + assert.IsType(t, tt.expectedErr, err, "Error type mismatch") 237 + if expectedILE, ok := tt.expectedErr.(*InvalidLabelError); ok { 238 + actualILE := &InvalidLabelError{} 239 + if assert.True(t, errors.As(err, &actualILE)) { 240 + assert.Equal(t, expectedILE.Length, actualILE.Length) 241 + } 242 + } 243 + } else { 244 + assert.NoError(t, err, "Expected no error but got one") 245 + assert.Equal(t, tt.expected, result, "Encoded domain does not match expected output") 246 + assert.Equal(t, tt.newOffsets, currentOffsets, "Offsets map does not match expected state") 247 + } 151 248 }) 152 249 } 153 250 } ··· 158 255 0x03, 0x63, 0x6f, 0x6d, 0x00, 159 256 }, 160 257 { 161 - 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xC0, 0x00, 258 + 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x01, 0x63, 0xc0, 0x00, 259 + }, 260 + { 261 + 0x03, 0x63, 0x6f, 0x6d, 0xc0, 0x00, 262 + }, 263 + { 264 + 0xc0, 0x00, 265 + }, 266 + { 267 + 0xc0, 0xff, 162 268 }, 163 269 { 164 - 0x03, 0x63, 0x6f, 0x6d, 0xC0, 0x00, 270 + 0x40, 271 + }, 272 + { 273 + 0x03, 0x63, 0x6f, 274 + }, 275 + { 276 + 0xc0, 165 277 }, 166 278 } 167 279 for _, tc := range testcases { 168 280 f.Add(tc) 169 281 } 170 282 f.Fuzz(func(t *testing.T, msg []byte) { 171 - decode_domain(msg, 0) 283 + _, _, err := decodeDomain(msg, 0) 284 + if err != nil { 285 + var bufErr *BufferOverflowError 286 + var labelErr *InvalidLabelError 287 + var compErr *DomainCompressionError 288 + 289 + if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr)) { 290 + t.Errorf("Fuzzing decodeDomain: unexpected error type %T: %v for input %x", err, err, msg) 291 + } 292 + } 172 293 }) 173 294 }
+24 -10
errors.go
··· 11 11 } 12 12 13 13 func (e *BufferOverflowError) Error() string { 14 - return fmt.Sprintf("magna: offset %d is past the buffer length %d", e.Offset, e.Length) 14 + return fmt.Sprintf("buffer overflow: attempted to read past buffer length %d at offset %d", e.Length, e.Offset) 15 + } 16 + 17 + func (e *BufferOverflowError) Is(target error) bool { 18 + _, ok := target.(*BufferOverflowError) 19 + return ok 15 20 } 16 21 17 22 // InvalidLabelError represents an error when an invalid label length is encountered. ··· 20 25 } 21 26 22 27 func (e *InvalidLabelError) Error() string { 23 - return fmt.Sprintf("magna: received invalid label length %d", e.Length) 28 + if e.Length > 63 { 29 + return fmt.Sprintf("invalid domain label: length %d exceeds maximum 63", e.Length) 30 + } 31 + if e.Length == 0 { 32 + return "invalid domain label: zero length label encountered" 33 + } 34 + 35 + // XXX: this should be unreachable 36 + return fmt.Sprintf("invalid domain label: unexpected length %d", e.Length) 37 + } 38 + 39 + func (e *InvalidLabelError) Is(target error) bool { 40 + _, ok := target.(*InvalidLabelError) 41 + return ok 24 42 } 25 43 26 44 // DomainCompressionError represents an error related to domain compression. 27 45 type DomainCompressionError struct{} 28 46 29 47 func (e *DomainCompressionError) Error() string { 30 - return "magna: loop detected in domain compression" 48 + return "invalid domain compression: pointer loop detected" 31 49 } 32 50 33 - // MagnaError represents a generic error with a custom message. 34 - type MagnaError struct { 35 - Message string 36 - } 37 - 38 - func (e *MagnaError) Error() string { 39 - return fmt.Sprintf("magna: %s", e.Message) 51 + func (e *DomainCompressionError) Is(target error) bool { 52 + _, ok := target.(*DomainCompressionError) 53 + return ok 40 54 }
+8 -28
errors_test.go
··· 1 1 package magna 2 2 3 3 import ( 4 - "fmt" 5 4 "testing" 6 5 7 6 "github.com/stretchr/testify/assert" ··· 14 13 offset int 15 14 expected string 16 15 }{ 17 - {"PositiveOffset", 10, 15, "magna: offset 15 is past the buffer length 10"}, 18 - {"ZeroLengthBuffer", 0, 5, "magna: offset 5 is past the buffer length 0"}, 19 - {"NegativeOffset", 10, -1, "magna: offset -1 is past the buffer length 10"}, 20 - {"EqualOffset", 10, 10, "magna: offset 10 is past the buffer length 10"}, 16 + {"PositiveOffset", 10, 15, "buffer overflow: attempted to read past buffer length 10 at offset 15"}, 17 + {"ZeroLengthBuffer", 0, 5, "buffer overflow: attempted to read past buffer length 0 at offset 5"}, 18 + {"NegativeOffset", 10, -1, "buffer overflow: attempted to read past buffer length 10 at offset -1"}, 19 + {"EqualOffset", 10, 10, "buffer overflow: attempted to read past buffer length 10 at offset 10"}, 21 20 } 22 21 23 22 for _, tt := range tests { ··· 34 33 length int 35 34 expected string 36 35 }{ 37 - {"LengthTooLarge", 64, "magna: received invalid label length 64"}, 38 - {"LengthZero", 0, "magna: received invalid label length 0"}, 39 - {"NegativeLength", -1, "magna: received invalid label length -1"}, 36 + {"LengthTooLarge", 64, "invalid domain label: length 64 exceeds maximum 63"}, 37 + {"LengthZero", 0, "invalid domain label: zero length label encountered"}, 38 + {"ValidLength", 30, "invalid domain label: unexpected length 30"}, 40 39 } 41 40 42 41 for _, tt := range tests { ··· 50 49 func TestDomainCompressionError(t *testing.T) { 51 50 t.Run("Standard", func(t *testing.T) { 52 51 err := &DomainCompressionError{} 53 - expected := "magna: loop detected in domain compression" 52 + expected := "invalid domain compression: pointer loop detected" 54 53 assert.Equal(t, expected, err.Error(), "Error() output mismatch") 55 54 }) 56 55 } 57 - 58 - func TestMagnaError(t *testing.T) { 59 - tests := []struct { 60 - name string 61 - message string 62 - expected string 63 - }{ 64 - {"EmptyMessage", "", "magna: "}, 65 - {"SimpleMessage", "test error", "magna: test error"}, 66 - {"MessageWithPunctuation", "error: invalid input!", "magna: error: invalid input!"}, 67 - } 68 - 69 - for _, tt := range tests { 70 - t.Run(tt.name, func(t *testing.T) { 71 - err := &MagnaError{Message: tt.message} 72 - assert.Equal(t, tt.expected, err.Error()) 73 - }) 74 - } 75 - }
+10 -7
header.go
··· 1 1 package magna 2 2 3 - import "encoding/binary" 3 + import ( 4 + "encoding/binary" 5 + "fmt" 6 + ) 4 7 5 8 // Decode decodes the header from the bytes. 6 9 func (h *Header) Decode(buf []byte, offset int) (int, error) { 7 10 var err error 8 11 h.ID, offset, err = getU16(buf, offset) 9 12 if err != nil { 10 - return len(buf), err 13 + return len(buf), fmt.Errorf("header decode: failed to read ID: %w", err) 11 14 } 12 15 13 16 flags, offset, err := getU16(buf, offset) 14 17 if err != nil { 15 - return len(buf), err 18 + return len(buf), fmt.Errorf("header decode: failed to read flags: %w", err) 16 19 } 17 20 18 21 h.QDCount, offset, err = getU16(buf, offset) 19 22 if err != nil { 20 - return len(buf), err 23 + return len(buf), fmt.Errorf("header decode: failed to read QDCount: %w", err) 21 24 } 22 25 23 26 h.ANCount, offset, err = getU16(buf, offset) 24 27 if err != nil { 25 - return len(buf), err 28 + return len(buf), fmt.Errorf("header decode: failed to read ANCount: %w", err) 26 29 } 27 30 28 31 h.NSCount, offset, err = getU16(buf, offset) 29 32 if err != nil { 30 - return len(buf), err 33 + return len(buf), fmt.Errorf("header decode: failed to read NSCount: %w", err) 31 34 } 32 35 33 36 h.ARCount, offset, err = getU16(buf, offset) 34 37 if err != nil { 35 - return len(buf), err 38 + return len(buf), fmt.Errorf("header decode: failed to read ARCount: %w", err) 36 39 } 37 40 38 41 h.QR = ((flags >> 15) & 0x01) == 1
+81 -60
header_test.go
··· 1 1 package magna 2 2 3 3 import ( 4 + "bytes" 4 5 "encoding/binary" 6 + "errors" 5 7 "testing" 6 8 7 9 "github.com/stretchr/testify/assert" 10 + "github.com/stretchr/testify/require" 8 11 ) 9 12 10 13 func TestHeaderDecode(t *testing.T) { ··· 14 17 expectedHeader Header 15 18 expectedOffset int 16 19 expectedErr error 20 + wantErrMsg string 17 21 }{ 18 22 { 19 23 name: "Valid header", ··· 37 41 expectedErr: nil, 38 42 }, 39 43 { 40 - name: "Insufficient buffer length", 44 + name: "Insufficient buffer length for Flags", 41 45 input: []byte{0x12, 0x34, 0x81}, 42 - expectedHeader: Header{}, 46 + expectedHeader: Header{ID: 0x1234}, 43 47 expectedOffset: 3, 44 - expectedErr: &BufferOverflowError{Length: 3, Offset: 3}, 48 + expectedErr: &BufferOverflowError{}, 49 + wantErrMsg: "header decode: failed to read flags", 45 50 }, 46 51 { 47 - name: "Invalid ID", 52 + name: "Insufficient buffer length for ID", 48 53 input: []byte{0x12}, 49 54 expectedHeader: Header{}, 50 55 expectedOffset: 1, 51 - expectedErr: &BufferOverflowError{Length: 1, Offset: 1}, 56 + expectedErr: &BufferOverflowError{}, 57 + wantErrMsg: "header decode: failed to read ID", 52 58 }, 53 59 { 54 60 name: "Missing QDCount", 55 61 input: []byte{0x12, 0x34, 0x81, 0x80, 0x00}, 56 - expectedHeader: Header{}, 62 + expectedHeader: Header{ID: 0x1234}, 57 63 expectedOffset: 5, 58 64 expectedErr: &BufferOverflowError{}, 65 + wantErrMsg: "header decode: failed to read QDCount", 59 66 }, 60 67 { 61 68 name: "Missing ANCount", 62 - input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01}, 63 - expectedHeader: Header{}, 64 - expectedOffset: 6, 69 + input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00}, 70 + expectedHeader: Header{ID: 0x1234, QDCount: 1}, 71 + expectedOffset: 7, 65 72 expectedErr: &BufferOverflowError{}, 73 + wantErrMsg: "header decode: failed to read ANCount", 66 74 }, 67 75 { 68 76 name: "Missing NSCount", 69 - input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02}, 70 - expectedHeader: Header{}, 71 - expectedOffset: 8, 77 + input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00}, 78 + expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2}, 79 + expectedOffset: 9, 72 80 expectedErr: &BufferOverflowError{}, 81 + wantErrMsg: "header decode: failed to read NSCount", 73 82 }, 74 83 { 75 84 name: "Missing ARCount", 76 - input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03}, 77 - expectedHeader: Header{}, 78 - expectedOffset: 10, 85 + input: []byte{0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00}, 86 + expectedHeader: Header{ID: 0x1234, QDCount: 1, ANCount: 2, NSCount: 3}, 87 + expectedOffset: 11, 79 88 expectedErr: &BufferOverflowError{}, 89 + wantErrMsg: "header decode: failed to read ARCount", 80 90 }, 81 91 } 82 92 ··· 86 96 offset, err := h.Decode(tt.input, 0) 87 97 88 98 if tt.expectedErr != nil { 89 - assert.Error(t, err) 90 - assert.IsType(t, tt.expectedErr, err) 99 + assert.Error(t, err, "Expected an error but got nil") 100 + 101 + assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr) 102 + 103 + if tt.wantErrMsg != "" { 104 + assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch") 105 + } 106 + 107 + assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error") 91 108 } else { 92 - assert.NoError(t, err) 93 - assert.Equal(t, tt.expectedHeader, *h) 109 + assert.NoError(t, err, "Expected no error but got one") 110 + 111 + assert.Equal(t, tt.expectedHeader, *h, "Header content mismatch") 112 + assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success") 94 113 } 95 - 96 - assert.Equal(t, tt.expectedOffset, offset) 97 114 }) 98 115 } 99 116 } ··· 160 177 } 161 178 162 179 h := &Header{} 163 - _, err := h.Decode(input, 0) 180 + offset, err := h.Decode(input, 0) 164 181 165 182 assert.NoError(t, err) 166 - assert.Equal(t, tt.expected.QR, h.QR) 167 - assert.Equal(t, tt.expected.OPCode, h.OPCode) 168 - assert.Equal(t, tt.expected.AA, h.AA) 169 - assert.Equal(t, tt.expected.TC, h.TC) 170 - assert.Equal(t, tt.expected.RD, h.RD) 171 - assert.Equal(t, tt.expected.RA, h.RA) 172 - assert.Equal(t, tt.expected.Z, h.Z) 173 - assert.Equal(t, tt.expected.RCode, h.RCode) 183 + assert.Equal(t, 12, offset, "Offset should be 12 after decoding full header") 184 + 185 + assert.Equal(t, tt.expected.QR, h.QR, "QR flag mismatch") 186 + assert.Equal(t, tt.expected.OPCode, h.OPCode, "OPCode mismatch") 187 + assert.Equal(t, tt.expected.AA, h.AA, "AA flag mismatch") 188 + assert.Equal(t, tt.expected.TC, h.TC, "TC flag mismatch") 189 + assert.Equal(t, tt.expected.RD, h.RD, "RD flag mismatch") 190 + assert.Equal(t, tt.expected.RA, h.RA, "RA flag mismatch") 191 + assert.Equal(t, tt.expected.Z, h.Z, "Z value mismatch") 192 + assert.Equal(t, tt.expected.RCode, h.RCode, "RCode mismatch") 174 193 }) 175 194 } 176 195 } ··· 208 227 }, 209 228 }, 210 229 { 211 - name: "No flags set", 230 + name: "No flags set, different counts", 212 231 header: Header{ 213 232 ID: 0x5678, 214 233 QR: false, ··· 234 253 }, 235 254 }, 236 255 { 237 - name: "Mixed flags", 256 + name: "Mixed flags and counts", 238 257 header: Header{ 239 258 ID: 0x9abc, 240 259 QR: true, ··· 264 283 for _, tt := range tests { 265 284 t.Run(tt.name, func(t *testing.T) { 266 285 encoded := tt.header.Encode() 267 - assert.Equal(t, tt.expected, encoded) 286 + assert.Equal(t, tt.expected, encoded, "Encoded header mismatch") 268 287 }) 269 288 } 270 289 } ··· 287 306 } 288 307 289 308 encoded := originalHeader.Encode() 309 + assert.Len(t, encoded, 12, "Encoded header should be 12 bytes") 290 310 291 311 decodedHeader := &Header{} 292 312 offset, err := decodedHeader.Decode(encoded, 0) 293 313 294 - assert.NoError(t, err) 295 - assert.Equal(t, len(encoded), offset) 296 - assert.Equal(t, originalHeader, *decodedHeader) 314 + assert.NoError(t, err, "Decoding failed unexpectedly") 315 + assert.Equal(t, 12, offset, "Offset after decoding should be 12") 316 + assert.Equal(t, originalHeader, *decodedHeader, "Decoded header does not match original") 297 317 } 298 318 299 319 func TestHeaderEncodeFlagCombinations(t *testing.T) { ··· 316 336 for _, tc := range testCases { 317 337 t.Run(tc.name, func(t *testing.T) { 318 338 encoded := tc.header.Encode() 339 + require.Len(t, encoded, 12, "Encoded header length invariant") 340 + 319 341 flags := binary.BigEndian.Uint16(encoded[2:4]) 320 - assert.Equal(t, tc.expected, flags) 342 + assert.Equal(t, tc.expected, flags, "Flags value mismatch") 321 343 }) 322 344 } 323 345 } ··· 327 349 {0x12, 0x34, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04}, 328 350 {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 329 351 {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, 352 + {0x12, 0x34}, 353 + {0x12, 0x34, 0x81, 0x80, 0x00}, 354 + {}, 330 355 } 331 356 332 357 for _, tc := range testcases { ··· 334 359 } 335 360 336 361 f.Fuzz(func(t *testing.T, data []byte) { 337 - // limit to only 12 bytes 338 - if len(data) > 12 { 339 - data = data[0:12] 340 - } 341 - 342 362 h := &Header{} 343 363 offset, err := h.Decode(data, 0) 344 364 if err != nil { 345 - switch err.(type) { 346 - case *BufferOverflowError, *InvalidLabelError: 347 - // these are expected error types 348 - default: 349 - t.Errorf("unexpected error type: %T", err) 365 + var bofErr *BufferOverflowError 366 + if !errors.As(err, &bofErr) { 367 + t.Errorf("FuzzDecodeHeader: expected BufferOverflowError or wrapped BOF, got %T: %v", err, err) 368 + } 369 + if offset > len(data) { 370 + t.Errorf("FuzzDecodeHeader: offset (%d) > data length (%d) on error", offset, len(data)) 350 371 } 351 372 return 352 373 } 353 374 354 - if offset != len(data) { 355 - t.Errorf("offset (%d) does not match data length (%d)", offset, len(data)) 375 + if len(data) < 12 { 376 + t.Errorf("FuzzDecodeHeader: decoded successfully but input length %d < 12", len(data)) 377 + return 378 + } 379 + if offset != 12 { 380 + t.Errorf("FuzzDecodeHeader: successful decode offset (%d) != 12", offset) 356 381 } 357 382 358 383 if h.OPCode > 15 { 359 - t.Errorf("invalid OPCode: %d", h.OPCode) 384 + t.Errorf("FuzzDecodeHeader: invalid OPCode decoded: %d", h.OPCode) 360 385 } 361 - 362 386 if h.Z > 7 { 363 - t.Errorf("invalid Z value: %d", h.Z) 387 + t.Errorf("FuzzDecodeHeader: invalid Z value decoded: %d", h.Z) 364 388 } 365 - 366 389 if h.RCode > 15 { 367 - t.Errorf("invalid RCode: %d", h.RCode) 390 + t.Errorf("FuzzDecodeHeader: invalid RCode decoded: %d", h.RCode) 368 391 } 369 392 370 - encoded := h.Encode() 371 - if len(encoded) != len(data) { 372 - t.Errorf("encoded length (%d) does not match input length (%d)", len(encoded), len(data)) 373 - 374 - for i := 0; i < len(data); i++ { 375 - t.Errorf("mismatch at position: %d: encoded %02x, input: %02x", i, encoded[i], data[i]) 393 + if len(data) >= 12 { 394 + encoded := h.Encode() 395 + if !bytes.Equal(encoded, data[:12]) { 396 + t.Errorf("FuzzDecodeHeader: encode/decode mismatch\nInput: %x\nEncoded: %x", data[:12], encoded) 376 397 } 377 398 } 378 399 })
+41 -20
message.go
··· 1 1 package magna 2 2 3 3 import ( 4 + "fmt" 4 5 "math/rand" 5 6 ) 6 7 ··· 8 9 func (m *Message) Decode(buf []byte) (err error) { 9 10 offset, err := m.Header.Decode(buf, 0) 10 11 if err != nil { 11 - return err 12 + return fmt.Errorf("failed to decode message header: %w", err) 12 13 } 13 14 14 - for x := 0; x < int(m.Header.QDCount); x++ { 15 + m.Question = make([]Question, 0, m.Header.QDCount) 16 + for i := range m.Header.QDCount { 15 17 var question Question 16 18 offset, err = question.Decode(buf, offset) 17 19 if err != nil { 18 - return err 20 + return fmt.Errorf("failed to decode question #%d: %w", i+1, err) 19 21 } 20 22 21 23 m.Question = append(m.Question, question) 22 24 } 23 25 24 - for x := 0; x < int(m.Header.ANCount); x++ { 26 + m.Answer = make([]ResourceRecord, 0, m.Header.ANCount) 27 + for i := range m.Header.ANCount { 25 28 var rr ResourceRecord 26 29 offset, err = rr.Decode(buf, offset) 27 30 if err != nil { 28 - return err 31 + return fmt.Errorf("failed to decode answer record #%d: %w", i+1, err) 29 32 } 30 33 31 34 m.Answer = append(m.Answer, rr) 32 35 } 33 36 34 - for x := 0; x < int(m.Header.NSCount); x++ { 37 + m.Authority = make([]ResourceRecord, 0, m.Header.NSCount) 38 + for i := range m.Header.NSCount { 35 39 var rr ResourceRecord 36 40 offset, err = rr.Decode(buf, offset) 37 41 if err != nil { 38 - return err 42 + return fmt.Errorf("failed to decode authority record #%d: %w", i+1, err) 39 43 } 40 44 41 45 m.Authority = append(m.Authority, rr) 42 46 } 43 47 44 - for x := 0; x < int(m.Header.ARCount); x++ { 48 + m.Additional = make([]ResourceRecord, 0, m.Header.ARCount) 49 + for i := range m.Header.ARCount { 45 50 var rr ResourceRecord 46 51 offset, err = rr.Decode(buf, offset) 47 52 if err != nil { 48 - return err 53 + return fmt.Errorf("failed to decode additional record #%d: %w", i+1, err) 49 54 } 50 55 51 56 m.Additional = append(m.Additional, rr) ··· 56 61 57 62 // Encode encodes a message to a DNS packet. 58 63 // TODO: set truncation bit if over 512 and udp is protocol 59 - func (m *Message) Encode() []byte { 64 + func (m *Message) Encode() ([]byte, error) { 60 65 m.offsets = make(map[string]uint16) 61 66 bytes := make([]byte, 0, 512) 62 - bytes = append(bytes, m.Header.Encode()...) 67 + 68 + headerBytes := m.Header.Encode() 69 + bytes = append(bytes, headerBytes...) 70 + 71 + var err error 63 72 64 - for _, question := range m.Question { 65 - bytes = question.Encode(bytes, &m.offsets) 73 + for i, question := range m.Question { 74 + bytes, err = question.Encode(bytes, &m.offsets) 75 + if err != nil { 76 + return nil, fmt.Errorf("failed to encode question #%d (%s): %w", i+1, question.QName, err) 77 + } 66 78 } 67 79 68 - for _, answer := range m.Answer { 69 - bytes = answer.Encode(bytes, &m.offsets) 80 + for i, answer := range m.Answer { 81 + bytes, err = answer.Encode(bytes, &m.offsets) 82 + if err != nil { 83 + return nil, fmt.Errorf("failed to encode answer record #%d (%s): %w", i+1, answer.Name, err) 84 + } 70 85 } 71 86 72 - for _, authority := range m.Authority { 73 - bytes = authority.Encode(bytes, &m.offsets) 87 + for i, authority := range m.Authority { 88 + bytes, err = authority.Encode(bytes, &m.offsets) 89 + if err != nil { 90 + return nil, fmt.Errorf("failed to encode authority record #%d (%s): %w", i+1, authority.Name, err) 91 + } 74 92 } 75 93 76 - for _, additional := range m.Additional { 77 - bytes = additional.Encode(bytes, &m.offsets) 94 + for i, additional := range m.Additional { 95 + bytes, err = additional.Encode(bytes, &m.offsets) 96 + if err != nil { 97 + return nil, fmt.Errorf("failed to encode additional record #%d (%s): %w", i+1, additional.Name, err) 98 + } 78 99 } 79 100 80 - return bytes 101 + return bytes, nil 81 102 } 82 103 83 104 func CreateRequest(op OPCode, rd bool) *Message {
+232 -52
message_test.go
··· 3 3 import ( 4 4 "bytes" 5 5 "encoding/binary" 6 + "errors" 6 7 "net" 8 + "strings" 7 9 "testing" 8 10 9 11 "github.com/stretchr/testify/assert" 12 + "github.com/stretchr/testify/require" 10 13 ) 11 14 12 15 func TestMessageDecode(t *testing.T) { 16 + buildQuery := func(id uint16, qname string, qtype DNSType, qclass DNSClass) []byte { 17 + buf := new(bytes.Buffer) 18 + binary.Write(buf, binary.BigEndian, id) 19 + binary.Write(buf, binary.BigEndian, uint16(0x0100)) 20 + binary.Write(buf, binary.BigEndian, uint16(1)) 21 + binary.Write(buf, binary.BigEndian, uint16(0)) 22 + binary.Write(buf, binary.BigEndian, uint16(0)) 23 + binary.Write(buf, binary.BigEndian, uint16(0)) 24 + offsets := make(map[string]uint16) 25 + qBytes, err := encodeDomain([]byte{}, qname, &offsets) 26 + require.NoError(t, err) 27 + buf.Write(qBytes) 28 + binary.Write(buf, binary.BigEndian, uint16(qtype)) 29 + binary.Write(buf, binary.BigEndian, uint16(qclass)) 30 + return buf.Bytes() 31 + } 32 + 33 + buildAnswer := func(id uint16, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdata ResourceRecordData) []byte { 34 + buf := new(bytes.Buffer) 35 + 36 + binary.Write(buf, binary.BigEndian, id) 37 + binary.Write(buf, binary.BigEndian, uint16(0x8180)) 38 + binary.Write(buf, binary.BigEndian, uint16(0)) 39 + binary.Write(buf, binary.BigEndian, uint16(1)) 40 + binary.Write(buf, binary.BigEndian, uint16(0)) 41 + binary.Write(buf, binary.BigEndian, uint16(0)) 42 + rr := ResourceRecord{ 43 + Name: name, 44 + RType: rtype, 45 + RClass: rclass, 46 + TTL: ttl, 47 + RData: rdata, 48 + } 49 + offsets := make(map[string]uint16) 50 + rrBytes, err := rr.Encode([]byte{}, &offsets) 51 + require.NoError(t, err) 52 + buf.Write(rrBytes) 53 + return buf.Bytes() 54 + } 55 + 13 56 tests := []struct { 14 - name string 15 - input []byte 16 - expected Message 17 - wantErr bool 57 + name string 58 + input []byte 59 + expected Message 60 + wantErr bool 61 + wantErrType error 62 + wantErrMsg string 18 63 }{ 19 64 { 20 - name: "Valid DNS message with one question", 21 - input: func() []byte { 22 - buf := new(bytes.Buffer) 23 - binary.Write(buf, binary.BigEndian, uint16(1234)) 24 - binary.Write(buf, binary.BigEndian, uint16(0x0100)) 25 - binary.Write(buf, binary.BigEndian, uint16(1)) 26 - binary.Write(buf, binary.BigEndian, uint16(0)) 27 - binary.Write(buf, binary.BigEndian, uint16(0)) 28 - binary.Write(buf, binary.BigEndian, uint16(0)) 29 - buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}) 30 - binary.Write(buf, binary.BigEndian, uint16(1)) 31 - binary.Write(buf, binary.BigEndian, uint16(1)) 32 - return buf.Bytes() 33 - }(), 65 + name: "Valid DNS query message with one question", 66 + input: buildQuery(1234, "www.example.com", AType, IN), 34 67 expected: Message{ 35 68 Header: Header{ 36 69 ID: 1234, 37 70 QR: false, 38 71 RD: true, 39 - OPCode: 0, 72 + OPCode: OPCode(0), 40 73 QDCount: 1, 74 + RCode: NOERROR, 41 75 }, 42 76 Question: []Question{ 43 77 { 44 78 QName: "www.example.com", 45 - QType: 1, 46 - QClass: 1, 79 + QType: AType, 80 + QClass: IN, 47 81 }, 48 82 }, 83 + Answer: []ResourceRecord{}, 84 + Additional: []ResourceRecord{}, 85 + Authority: []ResourceRecord{}, 49 86 }, 50 87 wantErr: false, 51 88 }, 52 89 { 53 - name: "Valid DNS message with one answer", 54 - input: func() []byte { 55 - buf := new(bytes.Buffer) 56 - binary.Write(buf, binary.BigEndian, uint16(5678)) 57 - binary.Write(buf, binary.BigEndian, uint16(0x8180)) 58 - binary.Write(buf, binary.BigEndian, uint16(0)) 59 - binary.Write(buf, binary.BigEndian, uint16(1)) 60 - binary.Write(buf, binary.BigEndian, uint16(0)) 61 - binary.Write(buf, binary.BigEndian, uint16(0)) 62 - buf.Write([]byte{3, 'w', 'w', 'w', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0}) 63 - binary.Write(buf, binary.BigEndian, uint16(1)) 64 - binary.Write(buf, binary.BigEndian, uint16(1)) 65 - binary.Write(buf, binary.BigEndian, uint32(3600)) 66 - binary.Write(buf, binary.BigEndian, uint16(4)) 67 - binary.Write(buf, binary.BigEndian, uint32(0x0A000001)) 68 - return buf.Bytes() 69 - }(), 90 + name: "Valid DNS answer message with one A record", 91 + input: buildAnswer(5678, "www.example.com", AType, IN, 3600, 92 + &A{Address: net.ParseIP("10.0.0.1").To4()}, 93 + ), 70 94 expected: Message{ 71 95 Header: Header{ 72 96 ID: 5678, ··· 78 102 RCode: 0, 79 103 ANCount: 1, 80 104 }, 105 + Question: []Question{}, 81 106 Answer: []ResourceRecord{ 82 107 { 83 108 Name: "www.example.com", 84 - RType: 1, 85 - RClass: 1, 109 + RType: AType, 110 + RClass: IN, 86 111 TTL: 3600, 87 112 RDLength: 4, 88 - RData: &A{net.IP([]byte{10, 0, 0, 1})}, 113 + RData: &A{Address: net.IP([]byte{10, 0, 0, 1})}, 89 114 }, 90 115 }, 116 + Additional: []ResourceRecord{}, 117 + Authority: []ResourceRecord{}, 91 118 }, 92 119 wantErr: false, 93 120 }, 94 121 { 95 - name: "Invalid input - empty buffer", 96 - input: []byte{}, 97 - wantErr: true, 122 + name: "Invalid input - empty buffer", 123 + input: []byte{}, 124 + wantErr: true, 125 + wantErrType: &BufferOverflowError{}, 126 + wantErrMsg: "failed to decode message header: header decode: failed to read ID", 127 + }, 128 + { 129 + name: "Invalid input - truncated header", 130 + input: []byte{0x12, 0x34}, 131 + wantErr: true, 132 + wantErrType: &BufferOverflowError{}, 133 + wantErrMsg: "failed to decode message header: header decode: failed to read flags", 134 + }, 135 + { 136 + name: "Invalid input - truncated question name", 137 + input: func() []byte { 138 + buf := new(bytes.Buffer) 139 + binary.Write(buf, binary.BigEndian, uint16(1235)) 140 + binary.Write(buf, binary.BigEndian, uint16(0x0100)) 141 + binary.Write(buf, binary.BigEndian, uint16(1)) 142 + binary.Write(buf, binary.BigEndian, uint16(0)) 143 + binary.Write(buf, binary.BigEndian, uint16(0)) 144 + binary.Write(buf, binary.BigEndian, uint16(0)) 145 + buf.Write([]byte{7, 'e', 'x', 'a'}) 146 + return buf.Bytes() 147 + }(), 148 + wantErr: true, 149 + wantErrType: &BufferOverflowError{}, 150 + wantErrMsg: "failed to decode question #1:", 151 + }, 152 + { 153 + name: "Invalid input - truncated answer record data", 154 + input: func() []byte { 155 + buf := new(bytes.Buffer) 156 + 157 + binary.Write(buf, binary.BigEndian, uint16(5679)) 158 + binary.Write(buf, binary.BigEndian, uint16(0x8180)) 159 + binary.Write(buf, binary.BigEndian, uint16(0)) 160 + binary.Write(buf, binary.BigEndian, uint16(1)) 161 + binary.Write(buf, binary.BigEndian, uint16(0)) 162 + binary.Write(buf, binary.BigEndian, uint16(0)) 163 + 164 + offsets := make(map[string]uint16) 165 + nameBytes, _ := encodeDomain([]byte{}, "example.com", &offsets) 166 + buf.Write(nameBytes) 167 + binary.Write(buf, binary.BigEndian, uint16(AType)) 168 + binary.Write(buf, binary.BigEndian, uint16(IN)) 169 + binary.Write(buf, binary.BigEndian, uint32(300)) 170 + binary.Write(buf, binary.BigEndian, uint16(4)) 171 + 172 + buf.Write([]byte{192, 168}) 173 + return buf.Bytes() 174 + }(), 175 + wantErr: true, 176 + wantErrType: &BufferOverflowError{}, 177 + wantErrMsg: "failed to decode answer record #1:", 98 178 }, 99 179 } 100 180 ··· 104 184 err := m.Decode(tt.input) 105 185 106 186 if tt.wantErr { 107 - assert.Error(t, err) 187 + assert.Error(t, err, "Expected an error but got nil") 188 + if tt.wantErrType != nil { 189 + assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T, expected %T", err, tt.wantErrType) 190 + } 191 + if tt.wantErrMsg != "" { 192 + assert.ErrorContains(t, err, tt.wantErrMsg, "Error message mismatch") 193 + } 108 194 } else { 109 - assert.NoError(t, err) 110 - assert.Equal(t, tt.expected.Header, m.Header) 111 - assert.Equal(t, tt.expected.Question, m.Question) 112 - assert.Equal(t, tt.expected.Answer, m.Answer) 113 - assert.Equal(t, tt.expected.Authority, m.Authority) 114 - assert.Equal(t, tt.expected.Additional, m.Additional) 195 + assert.NoError(t, err, "Expected no error but got one") 196 + 197 + assert.Equal(t, tt.expected.Header.ID, m.Header.ID, "Header ID mismatch") 198 + assert.Equal(t, tt.expected.Header.QR, m.Header.QR, "Header QR mismatch") 199 + assert.Equal(t, tt.expected.Header.OPCode, m.Header.OPCode, "Header OPCode mismatch") 200 + assert.Equal(t, tt.expected.Header.RCode, m.Header.RCode, "Header RCode mismatch") 201 + assert.Equal(t, tt.expected.Header.QDCount, m.Header.QDCount, "Header QDCount mismatch") 202 + assert.Equal(t, tt.expected.Header.ANCount, m.Header.ANCount, "Header ANCount mismatch") 203 + 204 + assert.Equal(t, tt.expected.Question, m.Question, "Question section mismatch") 205 + assert.Equal(t, tt.expected.Answer, m.Answer, "Answer section mismatch") 206 + assert.Equal(t, tt.expected.Authority, m.Authority, "Authority section mismatch") 207 + assert.Equal(t, tt.expected.Additional, m.Additional, "Additional section mismatch") 115 208 } 116 209 }) 117 210 } 118 211 } 119 212 213 + func TestMessageEncodeDecodeRoundTrip(t *testing.T) { 214 + tests := []struct { 215 + name string 216 + message *Message 217 + }{ 218 + { 219 + name: "Query with one question", 220 + message: CreateRequest(QUERY, true).AddQuestion(Question{ 221 + QName: "google.com", 222 + QType: AType, 223 + QClass: IN, 224 + }), 225 + }, 226 + { 227 + name: "Response with one A answer", 228 + message: &Message{ 229 + Header: Header{ 230 + ID: 12345, QR: true, OPCode: QUERY, RD: true, RA: true, RCode: NOERROR, ANCount: 1, 231 + }, 232 + Question: []Question{}, 233 + Answer: []ResourceRecord{ 234 + {Name: "test.local", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.ParseIP("192.0.2.1").To4()}}, 235 + }, 236 + Additional: []ResourceRecord{}, 237 + Authority: []ResourceRecord{}, 238 + }, 239 + }, 240 + { 241 + name: "Response with multiple answers and compression", 242 + message: &Message{ 243 + Header: Header{ID: 54321, QR: true, RCode: NOERROR, ANCount: 2}, 244 + Question: []Question{}, 245 + Answer: []ResourceRecord{ 246 + {Name: "www.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.2").To4()}}, 247 + {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.3").To4()}}, 248 + }, 249 + Additional: []ResourceRecord{}, 250 + Authority: []ResourceRecord{}, 251 + }, 252 + }, 253 + { 254 + name: "Message with various record types", 255 + message: &Message{ 256 + Header: Header{ID: 1111, QR: true, RCode: NOERROR, ANCount: 3}, 257 + Question: []Question{}, 258 + Answer: []ResourceRecord{ 259 + {Name: "example.com", RType: MXType, RClass: IN, TTL: 3600, RDLength: 9, RData: &MX{Preference: 10, Exchange: "mail.example.com"}}, 260 + {Name: "mail.example.com", RType: AType, RClass: IN, TTL: 300, RDLength: 4, RData: &A{net.ParseIP("192.0.2.4").To4()}}, 261 + {Name: "example.com", RType: TXTType, RClass: IN, TTL: 600, RDLength: 36, RData: &TXT{TxtData: []string{"v=spf1 include:_spf.google.com ~all"}}}, 262 + }, 263 + Additional: []ResourceRecord{}, 264 + Authority: []ResourceRecord{}, 265 + }, 266 + }, 267 + } 268 + 269 + for _, tt := range tests { 270 + t.Run(tt.name, func(t *testing.T) { 271 + encodedBytes, err := tt.message.Encode() 272 + require.NoError(t, err, "Encoding failed unexpectedly") 273 + require.NotEmpty(t, encodedBytes, "Encoded bytes should not be empty") 274 + 275 + decodedMsg := &Message{} 276 + err = decodedMsg.Decode(encodedBytes) 277 + require.NoError(t, err, "Decoding failed unexpectedly") 278 + 279 + assert.Equal(t, tt.message.Header.ID, decodedMsg.Header.ID, "Header ID mismatch") 280 + assert.Equal(t, tt.message.Header.QR, decodedMsg.Header.QR, "Header QR mismatch") 281 + assert.Equal(t, tt.message.Header.OPCode, decodedMsg.Header.OPCode, "Header OPCode mismatch") 282 + assert.Equal(t, tt.message.Header.RCode, decodedMsg.Header.RCode, "Header RCode mismatch") 283 + 284 + assert.Equal(t, tt.message.Question, decodedMsg.Question, "Question section mismatch") 285 + assert.Equal(t, tt.message.Answer, decodedMsg.Answer, "Answer section mismatch") 286 + assert.Equal(t, tt.message.Authority, decodedMsg.Authority, "Authority section mismatch") 287 + assert.Equal(t, tt.message.Additional, decodedMsg.Additional, "Additional section mismatch") 288 + }) 289 + } 290 + } 291 + 120 292 func FuzzDecodeMessage(f *testing.F) { 121 293 testcases := [][]byte{ 122 294 { ··· 131 303 } 132 304 f.Fuzz(func(t *testing.T, msg []byte) { 133 305 var m Message 134 - m.Decode(msg) 306 + err := m.Decode(msg) 307 + if err != nil { 308 + var bufErr *BufferOverflowError 309 + var labelErr *InvalidLabelError 310 + var compErr *DomainCompressionError 311 + if !(errors.As(err, &bufErr) || errors.As(err, &labelErr) || errors.As(err, &compErr) || strings.Contains(err.Error(), "record:")) { 312 + t.Errorf("FuzzDecodeMessage: unexpected error type %T: %v for input %x", err, err, msg) 313 + } 314 + } 135 315 }) 136 316 }
+16 -8
question.go
··· 1 1 package magna 2 2 3 - import "encoding/binary" 3 + import ( 4 + "encoding/binary" 5 + "fmt" 6 + ) 4 7 5 8 // Decode decodes a question from buf at the offset 6 9 func (q *Question) Decode(buf []byte, offset int) (int, error) { 7 10 var err error 8 - q.QName, offset, err = decode_domain(buf, offset) 11 + q.QName, offset, err = decodeDomain(buf, offset) 9 12 if err != nil { 10 - return offset, err 13 + return offset, fmt.Errorf("question decode: failed to decode QName: %w", err) 11 14 } 12 15 13 16 qtype, offset, err := getU16(buf, offset) 14 17 if err != nil { 15 - return offset, err 18 + return offset, fmt.Errorf("question decode: failed to decode QType for %s: %w", q.QName, err) 16 19 } 17 20 18 21 qclass, offset, err := getU16(buf, offset) 19 22 if err != nil { 20 - return offset, err 23 + return offset, fmt.Errorf("question decode: failed to decode QClass for %s: %w", q.QName, err) 21 24 } 22 25 23 26 q.QType = DNSType(qtype) ··· 26 29 } 27 30 28 31 // Encode serializes a Question into bytes, using a map to handle domain name compression offsets. 29 - func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) []byte { 30 - bytes = encode_domain(bytes, q.QName, offsets) 32 + func (q *Question) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 33 + var err error 34 + bytes, err = encodeDomain(bytes, q.QName, offsets) 35 + if err != nil { 36 + return nil, fmt.Errorf("question encode: failed to encode QName %s: %w", q.QName, err) 37 + } 38 + 31 39 bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QType)) 32 40 bytes = binary.BigEndian.AppendUint16(bytes, uint16(q.QClass)) 33 - return bytes 41 + return bytes, nil 34 42 }
+103 -82
question_test.go
··· 1 1 package magna 2 2 3 3 import ( 4 + "errors" 4 5 "testing" 5 6 6 7 "github.com/stretchr/testify/assert" 8 + "github.com/stretchr/testify/require" 7 9 ) 8 10 9 11 func TestQuestionDecode(t *testing.T) { ··· 13 15 expectedOffset int 14 16 expected Question 15 17 expectedErr error 18 + wantErrMsg string 16 19 }{ 17 20 { 18 21 name: "Valid question - example.com A IN", ··· 37 40 expectedErr: nil, 38 41 }, 39 42 { 40 - name: "Invalid domain name", 41 - input: []byte{255, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1}, 43 + name: "Invalid domain name - label too long", 44 + input: []byte{64, 'i', 'n', 'v', 'a', 'l', 'i', 'd', 0, 0, 1, 0, 1}, 42 45 expectedOffset: 13, 43 46 expected: Question{}, 44 - expectedErr: &BufferOverflowError{}, 47 + expectedErr: &InvalidLabelError{}, 48 + wantErrMsg: "question decode: failed to decode QName: invalid domain label: length 64 exceeds maximum 63", 49 + }, 50 + { 51 + name: "Invalid domain name - compression loop", 52 + input: []byte{0xC0, 0x00, 0, 1, 0, 1}, 53 + expectedOffset: 6, 54 + expected: Question{}, 55 + expectedErr: &DomainCompressionError{}, 56 + wantErrMsg: "question decode: failed to decode QName: invalid domain compression: pointer loop detected", 45 57 }, 46 58 { 47 59 name: "Insufficient buffer for QType", ··· 49 61 expectedOffset: 14, 50 62 expected: Question{QName: "example.com"}, 51 63 expectedErr: &BufferOverflowError{}, 64 + wantErrMsg: "question decode: failed to decode QType for example.com: buffer overflow", 52 65 }, 53 66 { 54 67 name: "Insufficient buffer for QClass", ··· 56 69 expectedOffset: 16, 57 70 expected: Question{QName: "example.com", QType: DNSType(1)}, 58 71 expectedErr: &BufferOverflowError{}, 72 + wantErrMsg: "question decode: failed to decode QClass for example.com: buffer overflow", 59 73 }, 60 74 } 61 75 ··· 64 78 q := &Question{} 65 79 offset, err := q.Decode(tt.input, 0) 66 80 67 - assert.Equal(t, tt.expectedOffset, offset) 68 - 69 81 if tt.expectedErr != nil { 70 - assert.Error(t, err) 71 - assert.IsType(t, tt.expectedErr, err) 82 + assert.Error(t, err, "Expected an error but got nil") 83 + assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr) 84 + if tt.wantErrMsg != "" { 85 + assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch") 86 + } 87 + assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on error") 72 88 } else { 73 - assert.NoError(t, err) 74 - assert.Equal(t, tt.expected, *q) 89 + assert.NoError(t, err, "Expected no error but got one") 90 + assert.Equal(t, tt.expected, *q, "Decoded question mismatch") 91 + assert.Equal(t, tt.expectedOffset, offset, "Offset mismatch on success") 75 92 } 76 93 }) 77 94 } ··· 79 96 80 97 func TestQuestionEncode(t *testing.T) { 81 98 tests := []struct { 82 - name string 83 - question Question 84 - offsets map[string]uint16 85 - expected []byte 99 + name string 100 + question Question 101 + initialBuf []byte 102 + offsets map[string]uint16 103 + expected []byte 104 + expectedErr error 105 + wantErrMsg string 106 + newOffsets map[string]uint16 86 107 }{ 87 108 { 88 109 name: "Simple domain - example.com A IN", ··· 91 112 QType: DNSType(1), 92 113 QClass: DNSClass(1), 93 114 }, 94 - offsets: make(map[string]uint16), 95 - expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}, 115 + initialBuf: nil, 116 + offsets: make(map[string]uint16), 117 + expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}, 118 + newOffsets: map[string]uint16{"example.com": 0, "com": 8}, 96 119 }, 97 120 { 98 121 name: "Subdomain - subdomain.example.com AAAA IN", ··· 101 124 QType: DNSType(28), 102 125 QClass: DNSClass(1), 103 126 }, 104 - offsets: make(map[string]uint16), 105 - expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1}, 127 + initialBuf: nil, 128 + offsets: make(map[string]uint16), 129 + expected: []byte{9, 's', 'u', 'b', 'd', 'o', 'm', 'a', 'i', 'n', 7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 28, 0, 1}, 130 + newOffsets: map[string]uint16{"subdomain.example.com": 0, "example.com": 10, "com": 18}, 106 131 }, 107 132 { 108 133 name: "Different class - example.com MX CH", ··· 111 136 QType: DNSType(15), 112 137 QClass: DNSClass(3), 113 138 }, 114 - offsets: make(map[string]uint16), 115 - expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3}, 139 + initialBuf: nil, 140 + offsets: make(map[string]uint16), 141 + expected: []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 15, 0, 3}, 142 + newOffsets: map[string]uint16{"example.com": 0, "com": 8}, 116 143 }, 117 144 { 118 145 name: "Domain compression - example.com after subdomain.example.com", ··· 121 148 QType: DNSType(1), 122 149 QClass: DNSClass(1), 123 150 }, 151 + initialBuf: nil, 124 152 offsets: map[string]uint16{ 125 - "com": 22, 126 - "example.com": 19, 153 + "subdomain.example.com": 0, 154 + "example.com": 10, 155 + "com": 18, 156 + }, 157 + expected: []byte{0xC0, 0x0a, 0x00, 0x01, 0x00, 0x01}, 158 + newOffsets: map[string]uint16{ 159 + "subdomain.example.com": 0, 160 + "example.com": 10, 161 + "com": 18, 162 + }, 163 + }, 164 + { 165 + name: "Encode with initial buffer", 166 + question: Question{ 167 + QName: "test.org", 168 + QType: AType, 169 + QClass: IN, 170 + }, 171 + initialBuf: []byte{0xAA, 0xBB}, 172 + offsets: make(map[string]uint16), 173 + expected: []byte{0xAA, 0xBB, 4, 't', 'e', 's', 't', 3, 'o', 'r', 'g', 0, 0, 1, 0, 1}, 174 + newOffsets: map[string]uint16{"test.org": 2, "org": 7}, 175 + }, 176 + { 177 + name: "Encode invalid domain - label too long", 178 + question: Question{ 179 + QName: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.com", 180 + QType: AType, 181 + QClass: IN, 127 182 }, 128 - expected: []byte{0xC0, 0x13, 0x00, 0x01, 0x00, 0x01}, 183 + initialBuf: nil, 184 + offsets: make(map[string]uint16), 185 + expected: nil, 186 + expectedErr: &InvalidLabelError{}, 187 + wantErrMsg: "question encode: failed to encode QName", 188 + newOffsets: map[string]uint16{}, 129 189 }, 130 190 } 131 191 132 192 for _, tt := range tests { 133 193 t.Run(tt.name, func(t *testing.T) { 134 - result := tt.question.Encode(nil, &tt.offsets) 135 - assert.Equal(t, tt.expected, result) 194 + currentOffsets := make(map[string]uint16) 195 + for k, v := range tt.offsets { 196 + currentOffsets[k] = v 197 + } 198 + 199 + result, err := tt.question.Encode(tt.initialBuf, &currentOffsets) 136 200 137 - if len(tt.offsets) == 0 { 138 - expectedOffsets := map[string]uint16{ 139 - tt.question.QName: 0, 140 - } 141 - for i := 0; i < len(tt.question.QName); i++ { 142 - if tt.question.QName[i] == '.' { 143 - expectedOffsets[tt.question.QName[i+1:]] = uint16(i + 1) 144 - } 201 + if tt.expectedErr != nil { 202 + assert.Error(t, err, "Expected an error but got nil") 203 + assert.True(t, errors.Is(err, tt.expectedErr), "Error type mismatch. Got %T, expected %T", err, tt.expectedErr) 204 + if tt.wantErrMsg != "" { 205 + assert.ErrorContains(t, err, tt.wantErrMsg, "Wrapped error message mismatch") 145 206 } 146 - assert.Equal(t, expectedOffsets, tt.offsets) 207 + } else { 208 + assert.NoError(t, err, "Expected no error but got one") 209 + assert.Equal(t, tt.expected, result, "Encoded question mismatch") 210 + assert.Equal(t, tt.newOffsets, currentOffsets, "Final offsets mismatch") 147 211 } 148 212 }) 149 213 } ··· 183 247 for _, tt := range tests { 184 248 t.Run(tt.name, func(t *testing.T) { 185 249 offsets := make(map[string]uint16) 186 - encoded := tt.question.Encode(nil, &offsets) 250 + encoded, err := tt.question.Encode(nil, &offsets) 251 + require.NoError(t, err, "Encoding failed") 187 252 188 253 decodedQuestion := &Question{} 189 - _, err := decodedQuestion.Decode(encoded, 0) 254 + offset, err := decodedQuestion.Decode(encoded, 0) 190 255 191 - assert.NoError(t, err) 192 - assert.Equal(t, tt.question, *decodedQuestion) 256 + assert.NoError(t, err, "Decoding failed") 257 + assert.Equal(t, len(encoded), offset, "Offset after decoding should match encoded length") 258 + assert.Equal(t, tt.question, *decodedQuestion, "Decoded question does not match original") 193 259 }) 194 260 } 195 261 } 196 - 197 - func TestQuestionEncodeWithExistingBuffer(t *testing.T) { 198 - question := Question{ 199 - QName: "example.com", 200 - QType: DNSType(1), 201 - QClass: DNSClass(1), 202 - } 203 - 204 - existingBuffer := []byte{0xFF, 0xFF, 0xFF, 0xFF} 205 - offsets := make(map[string]uint16) 206 - 207 - result := question.Encode(existingBuffer, &offsets) 208 - 209 - expected := append( 210 - existingBuffer, 211 - []byte{7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0, 0, 1, 0, 1}..., 212 - ) 213 - 214 - assert.Equal(t, expected, result) 215 - } 216 - 217 - func TestQuestionEncodeLongDomainName(t *testing.T) { 218 - longLabel := make([]byte, 63) 219 - for i := range longLabel { 220 - longLabel[i] = 'a' 221 - } 222 - longDomainName := string(longLabel) + "." + string(longLabel) + "." + string(longLabel) + "." + string(longLabel[:61]) 223 - 224 - question := Question{ 225 - QName: longDomainName, 226 - QType: DNSType(1), 227 - QClass: DNSClass(1), 228 - } 229 - 230 - offsets := make(map[string]uint16) 231 - encoded := question.Encode(nil, &offsets) 232 - 233 - assert.Equal(t, 259, len(encoded)) 234 - 235 - decodedQuestion := &Question{} 236 - _, err := decodedQuestion.Decode(encoded, 0) 237 - 238 - assert.NoError(t, err) 239 - assert.Equal(t, question, *decodedQuestion) 240 - }
+208 -126
resource_record.go
··· 10 10 func (a *A) Decode(buf []byte, offset int, rdlength int) (int, error) { 11 11 bytes, offset, err := getSlice(buf, offset, rdlength) 12 12 if err != nil { 13 - return offset, err 13 + return offset, fmt.Errorf("A record: failed to read address data: %w", err) 14 14 } 15 15 16 16 a.Address = net.IP(bytes) 17 - return offset, err 17 + if a.Address.To4() == nil { 18 + return offset, fmt.Errorf("A record: decoded data is not a valid IPv4 address: %v", bytes) 19 + } 20 + return offset, nil 18 21 } 19 22 20 - func (a *A) Encode(bytes []byte, offsets *map[string]uint16) []byte { 21 - return append(bytes, a.Address.To4()...) 23 + func (a *A) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 24 + ipv4 := a.Address.To4() 25 + if ipv4 == nil { 26 + return nil, fmt.Errorf("A record: cannot encode non-IPv4 address %s", a.Address.String()) 27 + } 28 + 29 + return append(bytes, a.Address.To4()...), nil 22 30 } 23 31 24 32 func (a A) String() string { ··· 27 35 28 36 func (ns *NS) Decode(buf []byte, offset int, rdlength int) (int, error) { 29 37 var err error 30 - ns.NSDName, offset, err = decode_domain(buf, offset) 38 + ns.NSDName, offset, err = decodeDomain(buf, offset) 31 39 if err != nil { 32 - return offset, err 40 + return offset, fmt.Errorf("NS record: failed to decode NSDName: %w", err) 33 41 } 34 42 35 - return offset, err 43 + return offset, nil 36 44 } 37 45 38 - func (ns *NS) Encode(bytes []byte, offsets *map[string]uint16) []byte { 39 - return encode_domain(bytes, ns.NSDName, offsets) 46 + func (ns *NS) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 47 + var err error 48 + bytes, err = encodeDomain(bytes, ns.NSDName, offsets) 49 + if err != nil { 50 + return nil, fmt.Errorf("NS record: failed to encode NSDName %s: %w", ns.NSDName, err) 51 + } 52 + 53 + return bytes, nil 40 54 } 41 55 42 56 func (ns NS) String() string { ··· 45 59 46 60 func (md *MD) Decode(buf []byte, offset int, rdlength int) (int, error) { 47 61 var err error 48 - md.MADName, offset, err = decode_domain(buf, offset) 62 + md.MADName, offset, err = decodeDomain(buf, offset) 49 63 if err != nil { 50 - return offset, err 64 + return offset, fmt.Errorf("MD record: failed to decode MADName %s: %w", md.MADName, err) 51 65 } 52 66 53 - return offset, err 67 + return offset, nil 54 68 } 55 69 56 - func (md *MD) Encode(bytes []byte, offsets *map[string]uint16) []byte { 57 - return encode_domain(bytes, md.MADName, offsets) 70 + func (md *MD) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 71 + var err error 72 + bytes, err = encodeDomain(bytes, md.MADName, offsets) 73 + if err != nil { 74 + return nil, fmt.Errorf("MD record: failed to encode MADName %s: %w", md.MADName, err) 75 + } 76 + 77 + return bytes, nil 58 78 } 59 79 60 80 func (md MD) String() string { ··· 63 83 64 84 func (mf *MF) Decode(buf []byte, offset int, rdlength int) (int, error) { 65 85 var err error 66 - mf.MADName, offset, err = decode_domain(buf, offset) 86 + mf.MADName, offset, err = decodeDomain(buf, offset) 67 87 if err != nil { 68 - return offset, err 88 + return offset, fmt.Errorf("MF record: failed to decode MADName: %w", err) 69 89 } 70 90 71 - return offset, err 91 + return offset, nil 72 92 } 73 93 74 - func (mf *MF) Encode(bytes []byte, offsets *map[string]uint16) []byte { 75 - return encode_domain(bytes, mf.MADName, offsets) 94 + func (mf *MF) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 95 + var err error 96 + bytes, err = encodeDomain(bytes, mf.MADName, offsets) 97 + if err != nil { 98 + return nil, fmt.Errorf("MF record: failed to encode MADName %s: %w", mf.MADName, err) 99 + } 100 + 101 + return bytes, nil 76 102 } 77 103 78 104 func (mf MF) String() string { ··· 81 107 82 108 func (c *CNAME) Decode(buf []byte, offset int, rdlength int) (int, error) { 83 109 var err error 84 - c.CName, offset, err = decode_domain(buf, offset) 110 + c.CName, offset, err = decodeDomain(buf, offset) 85 111 if err != nil { 86 - return offset, err 112 + return offset, fmt.Errorf("CNAME record: failed to decode CNAME: %w", err) 87 113 } 88 114 89 - return offset, err 115 + return offset, nil 90 116 } 91 117 92 - func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint16) []byte { 93 - return encode_domain(bytes, c.CName, offsets) 118 + func (c *CNAME) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 119 + var err error 120 + bytes, err = encodeDomain(bytes, c.CName, offsets) 121 + if err != nil { 122 + return nil, fmt.Errorf("CNAME record: failed to encode CNAME %s: %w", c.CName, err) 123 + } 124 + 125 + return bytes, nil 94 126 } 95 127 96 128 func (c CNAME) String() string { ··· 99 131 100 132 func (soa *SOA) Decode(buf []byte, offset int, rdlength int) (int, error) { 101 133 var err error 102 - soa.MName, offset, err = decode_domain(buf, offset) 134 + soa.MName, offset, err = decodeDomain(buf, offset) 103 135 if err != nil { 104 - return offset, err 136 + return offset, fmt.Errorf("SOA record: failed to decode MName: %w", err) 105 137 } 106 138 107 - soa.RName, offset, err = decode_domain(buf, offset) 139 + soa.RName, offset, err = decodeDomain(buf, offset) 108 140 if err != nil { 109 - return offset, err 141 + return offset, fmt.Errorf("SOA record: failed to decode RName: %w", err) 110 142 } 111 143 112 144 soa.Serial, offset, err = getU32(buf, offset) 113 145 if err != nil { 114 - return offset, err 146 + return offset, fmt.Errorf("SOA record: failed to decode Serial: %w", err) 115 147 } 116 148 117 149 soa.Refresh, offset, err = getU32(buf, offset) 118 150 if err != nil { 119 - return offset, err 151 + return offset, fmt.Errorf("SOA record: failed to decode Refresh: %w", err) 120 152 } 121 153 122 154 soa.Retry, offset, err = getU32(buf, offset) 123 155 if err != nil { 124 - return offset, err 156 + return offset, fmt.Errorf("SOA record: failed to decode Retry: %w", err) 125 157 } 126 158 127 159 soa.Expire, offset, err = getU32(buf, offset) 128 160 if err != nil { 129 - return offset, err 161 + return offset, fmt.Errorf("SOA record: failed to decode Expire: %w", err) 130 162 } 131 163 132 164 soa.Minimum, offset, err = getU32(buf, offset) 133 165 if err != nil { 134 - return offset, err 166 + return offset, fmt.Errorf("SOA record: failed to decode Minimum: %w", err) 135 167 } 136 168 137 - return offset, err 169 + return offset, nil 138 170 } 139 171 140 - func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint16) []byte { 141 - bytes = encode_domain(bytes, soa.MName, offsets) 142 - bytes = encode_domain(bytes, soa.RName, offsets) 172 + func (soa *SOA) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 173 + var err error 174 + bytes, err = encodeDomain(bytes, soa.MName, offsets) 175 + if err != nil { 176 + return nil, fmt.Errorf("SOA record: failed to encode MName %s: %w", soa.MName, err) 177 + } 178 + 179 + bytes, err = encodeDomain(bytes, soa.RName, offsets) 180 + if err != nil { 181 + return nil, fmt.Errorf("SOA record: failed to encode RName %s: %w", soa.RName, err) 182 + } 183 + 143 184 bytes = binary.BigEndian.AppendUint32(bytes, soa.Serial) 144 185 bytes = binary.BigEndian.AppendUint32(bytes, soa.Refresh) 145 186 bytes = binary.BigEndian.AppendUint32(bytes, soa.Retry) 146 187 bytes = binary.BigEndian.AppendUint32(bytes, soa.Expire) 147 188 bytes = binary.BigEndian.AppendUint32(bytes, soa.Minimum) 148 189 149 - return bytes 190 + return bytes, nil 150 191 } 151 192 152 193 func (soa SOA) String() string { ··· 154 195 } 155 196 156 197 func (mb *MB) Decode(buf []byte, offset int, rdlength int) (int, error) { 157 - madname, offset, err := decode_domain(buf, offset) 198 + madname, offset, err := decodeDomain(buf, offset) 158 199 if err != nil { 159 - return offset, err 200 + return offset, fmt.Errorf("MB record: failed to decode MADName: %w", err) 160 201 } 161 202 162 203 mb.MADName = string(madname) 163 - return offset, err 204 + return offset, nil 164 205 } 165 206 166 - func (mb *MB) Encode(bytes []byte, offsets *map[string]uint16) []byte { 167 - return encode_domain(bytes, mb.MADName, offsets) 207 + func (mb *MB) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 208 + var err error 209 + bytes, err = encodeDomain(bytes, mb.MADName, offsets) 210 + if err != nil { 211 + return nil, fmt.Errorf("MB record: failed to encode MADName %s: %w", mb.MADName, err) 212 + } 213 + 214 + return bytes, nil 168 215 } 169 216 170 217 func (mb MB) String() string { ··· 173 220 174 221 func (mg *MG) Decode(buf []byte, offset int, rdlength int) (int, error) { 175 222 var err error 176 - mg.MGMName, offset, err = decode_domain(buf, offset) 223 + mg.MGMName, offset, err = decodeDomain(buf, offset) 177 224 if err != nil { 178 - return offset, err 225 + return offset, fmt.Errorf("MG record: failed to decode MGMName: %w", err) 179 226 } 180 227 181 - return offset, err 228 + return offset, nil 182 229 } 183 230 184 - func (mg *MG) Encode(bytes []byte, offsets *map[string]uint16) []byte { 185 - return encode_domain(bytes, mg.MGMName, offsets) 231 + func (mg *MG) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 232 + var err error 233 + bytes, err = encodeDomain(bytes, mg.MGMName, offsets) 234 + if err != nil { 235 + return nil, fmt.Errorf("MG record: failed to encode MGMName %s: %w", mg.MGMName, err) 236 + } 237 + return bytes, nil 186 238 } 187 239 188 240 func (mg MG) String() string { ··· 191 243 192 244 func (mr *MR) Decode(buf []byte, offset int, rdlength int) (int, error) { 193 245 var err error 194 - mr.NEWName, offset, err = decode_domain(buf, offset) 246 + mr.NEWName, offset, err = decodeDomain(buf, offset) 195 247 if err != nil { 196 - return offset, err 248 + return offset, fmt.Errorf("MR record: failed to decode NEWName: %w", err) 197 249 } 198 250 199 - return offset, err 251 + return offset, nil 200 252 } 201 253 202 - func (mr *MR) Encode(bytes []byte, offsets *map[string]uint16) []byte { 203 - return encode_domain(bytes, mr.NEWName, offsets) 254 + func (mr *MR) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 255 + var err error 256 + bytes, err = encodeDomain(bytes, mr.NEWName, offsets) 257 + if err != nil { 258 + return nil, fmt.Errorf("MR record: failed to encode NEWName: %w", err) 259 + } 260 + 261 + return bytes, nil 204 262 } 205 263 206 264 func (mr MR) String() string { ··· 211 269 var err error 212 270 null.Anything, offset, err = getSlice(buf, offset, int(rdlength)) 213 271 if err != nil { 214 - return offset, err 272 + return offset, fmt.Errorf("NULL record: failed to read data: %w", err) 215 273 } 216 274 217 - return offset, err 275 + return offset, nil 218 276 } 219 277 220 - func (null *NULL) Encode(bytes []byte, offsets *map[string]uint16) []byte { 221 - return append(bytes, null.Anything...) 278 + func (null *NULL) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 279 + return append(bytes, null.Anything...), nil 222 280 } 223 281 224 282 func (null NULL) String() string { ··· 227 285 228 286 func (wks *WKS) Decode(buf []byte, offset int, rdlength int) (int, error) { 229 287 if rdlength < 5 { 230 - return len(buf), &MagnaError{Message: fmt.Sprintf("magna: WKS RDLENGTH too short: %d", rdlength)} 288 + return len(buf), fmt.Errorf("WKS record: RDLENGTH %d is too short, minimum 5 required", rdlength) 231 289 } 232 290 233 291 addressBytes, nextOffset, err := getSlice(buf, offset, 4) 234 292 if err != nil { 235 - return len(buf), fmt.Errorf("magna: WKS error reading address: %w", err) 293 + return len(buf), fmt.Errorf("WKS record: failed to read address: %w", err) 236 294 } 237 295 offset = nextOffset 238 296 wks.Address = net.IP(addressBytes) 239 297 240 298 protocol, nextOffset, err := getU8(buf, offset) 241 299 if err != nil { 242 - return len(buf), fmt.Errorf("magna: WKS error reading protocol: %w", err) 300 + return len(buf), fmt.Errorf("WKS record: failed to read protocol: %w", err) 243 301 } 244 302 offset = nextOffset 245 303 wks.Protocol = protocol ··· 247 305 bitmapLength := rdlength - 5 248 306 wks.BitMap, nextOffset, err = getSlice(buf, offset, bitmapLength) 249 307 if err != nil { 250 - return len(buf), fmt.Errorf("magna: WKS error reading bitmap: %w", err) 308 + return len(buf), fmt.Errorf("WKS record: failed to read bitmap: %w", err) 251 309 } 252 310 offset = nextOffset 253 311 254 312 return offset, nil 255 313 } 256 314 257 - func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint16) []byte { 315 + func (wks *WKS) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 258 316 bytes = append(bytes, wks.Address.To4()...) 259 317 bytes = append(bytes, wks.Protocol) 260 318 bytes = append(bytes, wks.BitMap...) 261 319 262 - return bytes 320 + return bytes, nil 263 321 } 264 322 265 323 func (wks WKS) String() string { 266 - return fmt.Sprintf("%s %d %s", wks.Address.String(), wks.Protocol, wks.BitMap) 324 + return fmt.Sprintf("%s %d %x", wks.Address.String(), wks.Protocol, wks.BitMap) 267 325 } 268 326 269 327 func (ptr *PTR) Decode(buf []byte, offset int, rdlength int) (int, error) { 270 328 var err error 271 - ptr.PTRDName, offset, err = decode_domain(buf, offset) 329 + ptr.PTRDName, offset, err = decodeDomain(buf, offset) 272 330 if err != nil { 273 - return offset, err 331 + return offset, fmt.Errorf("PTR record: failed to decode PTRDName: %w", err) 274 332 } 275 333 276 - return offset, err 334 + return offset, nil 277 335 } 278 336 279 - func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint16) []byte { 280 - return encode_domain(bytes, ptr.PTRDName, offsets) 337 + func (ptr *PTR) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 338 + var err error 339 + bytes, err = encodeDomain(bytes, ptr.PTRDName, offsets) 340 + if err != nil { 341 + return nil, fmt.Errorf("PTR record: failed to encode PTRD %s: %w", ptr.PTRDName, err) 342 + } 343 + 344 + return bytes, nil 281 345 } 282 346 283 347 func (ptr PTR) String() string { ··· 285 349 } 286 350 287 351 func (hinfo *HINFO) Decode(buf []byte, offset int, rdlength int) (int, error) { 352 + startOffset := offset 288 353 endOffset := offset + rdlength 289 354 if endOffset > len(buf) { 290 355 return len(buf), &BufferOverflowError{Length: len(buf), Offset: endOffset} ··· 295 360 296 361 cpuLen, nextOffset, err := getU8(buf, currentOffset) 297 362 if err != nil { 298 - return len(buf), fmt.Errorf("magna: HINFO error reading CPU length: %w", err) 363 + return len(buf), fmt.Errorf("HINFO record: failed to read CPU length: %w", err) 299 364 } 300 365 currentOffset = nextOffset 301 366 if currentOffset+int(cpuLen) > endOffset { ··· 303 368 } 304 369 cpuBytes, nextOffset, err := getSlice(buf, currentOffset, int(cpuLen)) 305 370 if err != nil { 306 - return len(buf), fmt.Errorf("magna: HINFO error reading CPU data: %w", err) 371 + return len(buf), fmt.Errorf("HINFO record: failed to read CPU data: %w", err) 307 372 } 308 373 currentOffset = nextOffset 309 374 hinfo.CPU = string(cpuBytes) ··· 311 376 osLen, nextOffset, err := getU8(buf, currentOffset) 312 377 if err != nil { 313 378 if currentOffset == endOffset { 314 - return len(buf), &MagnaError{Message: "magna: HINFO missing OS string"} 379 + return len(buf), fmt.Errorf("HINFO record: missing OS length byte at offset %d (expected end: %d)", currentOffset, endOffset) 315 380 } 316 - return len(buf), fmt.Errorf("magna: HINFO error reading OS length: %w", err) 381 + return len(buf), fmt.Errorf("HINFO record: failed to read OS length: %w", err) 317 382 } 318 383 currentOffset = nextOffset 319 384 if currentOffset+int(osLen) > endOffset { ··· 321 386 } 322 387 osBytes, nextOffset, err := getSlice(buf, currentOffset, int(osLen)) 323 388 if err != nil { 324 - return len(buf), fmt.Errorf("magna: HINFO error reading OS data: %w", err) 389 + return len(buf), fmt.Errorf("HINFO record: failed to read OS data: %w", err) 325 390 } 326 391 currentOffset = nextOffset 327 392 hinfo.OS = string(osBytes) 328 393 329 394 if currentOffset != endOffset { 330 - return len(buf), &MagnaError{Message: fmt.Sprintf("magna: HINFO RDATA length mismatch, expected end at %d, ended at %d", endOffset, currentOffset)} 395 + return len(buf), fmt.Errorf("HINFO record: RDATA length mismatch, consumed %d bytes, expected %d", currentOffset-startOffset, rdlength) 331 396 } 332 397 333 398 return currentOffset, nil 334 399 } 335 400 336 - func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte { 337 - // XXX: should probally return an error 401 + func (hinfo *HINFO) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 338 402 if len(hinfo.CPU) > 255 { 339 - hinfo.CPU = hinfo.CPU[:255] 403 + return nil, fmt.Errorf("HINFO record: CPU string length %d exceeds maximum 255", len(hinfo.CPU)) 340 404 } 341 405 if len(hinfo.OS) > 255 { 342 - hinfo.OS = hinfo.OS[:255] 406 + return nil, fmt.Errorf("HINFO record: OS string length %d exceeds maximum 255", len(hinfo.OS)) 343 407 } 344 408 345 409 bytes = append(bytes, byte(len(hinfo.CPU))) 346 410 bytes = append(bytes, []byte(hinfo.CPU)...) 347 411 bytes = append(bytes, byte(len(hinfo.OS))) 348 412 bytes = append(bytes, []byte(hinfo.OS)...) 349 - return bytes 413 + return bytes, nil 350 414 } 351 415 352 416 func (hinfo HINFO) String() string { ··· 356 420 func (minfo *MINFO) Decode(buf []byte, offset int, rdlength int) (int, error) { 357 421 var err error 358 422 359 - minfo.RMailBx, offset, err = decode_domain(buf, offset) 423 + minfo.RMailBx, offset, err = decodeDomain(buf, offset) 360 424 if err != nil { 361 - return offset, err 425 + return offset, fmt.Errorf("MINFO record: failed to decode RMailBx: %w", err) 362 426 } 363 427 364 - minfo.EMailBx, offset, err = decode_domain(buf, offset) 428 + minfo.EMailBx, offset, err = decodeDomain(buf, offset) 365 429 if err != nil { 366 - return offset, err 430 + return offset, fmt.Errorf("MINFO record: failed to decode EMailBx: %w", err) 367 431 } 368 432 369 - return offset, err 433 + return offset, nil 370 434 } 371 435 372 - func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint16) []byte { 373 - bytes = encode_domain(bytes, minfo.RMailBx, offsets) 374 - bytes = encode_domain(bytes, minfo.EMailBx, offsets) 436 + func (minfo *MINFO) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 437 + var err error 438 + bytes, err = encodeDomain(bytes, minfo.RMailBx, offsets) 439 + if err != nil { 440 + return nil, fmt.Errorf("MINFO record: failed to encode RMailBx %s: %w", minfo.RMailBx, err) 441 + } 442 + 443 + bytes, err = encodeDomain(bytes, minfo.EMailBx, offsets) 444 + if err != nil { 445 + return nil, fmt.Errorf("MINFO record: failed to encode EMailBx %s: %w", minfo.EMailBx, err) 446 + } 375 447 376 - return bytes 448 + return bytes, nil 377 449 } 378 450 379 451 func (minfo MINFO) String() string { ··· 384 456 var err error 385 457 mx.Preference, offset, err = getU16(buf, offset) 386 458 if err != nil { 387 - return offset, err 459 + return offset, fmt.Errorf("MX record: failed to decode Preference: %w", err) 388 460 } 389 461 390 - mx.Exchange, offset, err = decode_domain(buf, offset) 462 + mx.Exchange, offset, err = decodeDomain(buf, offset) 391 463 if err != nil { 392 - return offset, err 464 + return offset, fmt.Errorf("MX record: failed to decode Exchange: %w", err) 393 465 } 394 466 395 - return offset, err 467 + return offset, nil 396 468 } 397 469 398 - func (mx *MX) Encode(bytes []byte, offsets *map[string]uint16) []byte { 470 + func (mx *MX) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 471 + var err error 399 472 bytes = binary.BigEndian.AppendUint16(bytes, mx.Preference) 400 - bytes = encode_domain(bytes, mx.Exchange, offsets) 473 + bytes, err = encodeDomain(bytes, mx.Exchange, offsets) 474 + if err != nil { 475 + return nil, fmt.Errorf("MX record: failed to encode Exchange %s: %w", mx.Exchange, err) 476 + } 401 477 402 - return bytes 478 + return bytes, nil 403 479 } 404 480 405 481 func (mx MX) String() string { ··· 417 493 for currentOffset < endOffset { 418 494 strLen, nextOffsetAfterLen, err := getU8(buf, currentOffset) 419 495 if err != nil { 420 - return len(buf), fmt.Errorf("magna: error reading TXT string length byte: %w", err) 496 + return len(buf), fmt.Errorf("TXT record: failed to read string length byte: %w", err) 421 497 } 422 498 423 499 nextOffsetAfterData := nextOffsetAfterLen + int(strLen) 424 500 if nextOffsetAfterData > endOffset { 425 - return len(buf), &MagnaError{ 426 - Message: fmt.Sprintf("magna: TXT string segment length %d at offset %d exceeds RDLENGTH boundary %d", strLen, nextOffsetAfterLen, endOffset), 427 - } 501 + return len(buf), fmt.Errorf("TXT record: string segment length %d exceeds RDLENGTH boundary %d", strLen, endOffset) 428 502 } 429 503 430 504 strBytes, actualNextOffsetAfterData, err := getSlice(buf, nextOffsetAfterLen, int(strLen)) 431 505 if err != nil { 432 - return len(buf), fmt.Errorf("magna: error reading TXT string data: %w", err) 506 + return len(buf), fmt.Errorf("TXT record: failed to read string data (length %d): %w", strLen, err) 433 507 } 434 508 435 509 txt.TxtData = append(txt.TxtData, string(strBytes)) ··· 437 511 } 438 512 439 513 if currentOffset != endOffset { 440 - return len(buf), &MagnaError{ 441 - Message: fmt.Sprintf("magna: TXT RDATA parsing finished at offset %d, but expected end at %d based on RDLENGTH", currentOffset, endOffset), 442 - } 514 + return len(buf), fmt.Errorf("TXT record: RDATA parsing finished at offset %d, but expected end at %d", currentOffset, endOffset) 443 515 } 444 516 445 517 return currentOffset, nil 446 518 } 447 519 448 - func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint16) []byte { 520 + func (txt *TXT) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 449 521 for _, s := range txt.TxtData { 450 522 if len(s) > 255 { 451 - // XXX: should return probably an error 452 - s = s[:255] 523 + return nil, fmt.Errorf("TXT record: string segment length %d exceeds maximum 255", len(s)) 453 524 } 454 525 bytes = append(bytes, byte(len(s))) 455 526 bytes = append(bytes, []byte(s)...) 456 527 } 457 - return bytes 528 + return bytes, nil 458 529 } 459 530 460 531 func (txt TXT) String() string { ··· 469 540 var err error 470 541 r.Bytes, offset, err = getSlice(buf, offset, int(rdlength)) 471 542 if err != nil { 472 - return offset, err 543 + return offset, fmt.Errorf("reserved record: failed to read data: %w", err) 473 544 } 474 545 475 546 return offset, err 476 547 } 477 548 478 - func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint16) []byte { 479 - return append(bytes, r.Bytes...) 549 + func (r *Reserved) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 550 + return append(bytes, r.Bytes...), nil 480 551 } 481 552 482 553 func (r Reserved) String() string { 483 - return string(r.Bytes) 554 + return fmt.Sprintf("[Reserved Data: %x]", r.Bytes) 484 555 } 485 556 486 557 // Decode decodes a resource record from buf at the offset. 487 558 func (r *ResourceRecord) Decode(buf []byte, offset int) (int, error) { 488 - name, offset, err := decode_domain(buf, offset) 559 + var err error 560 + r.Name, offset, err = decodeDomain(buf, offset) 489 561 if err != nil { 490 - return offset, err 562 + return offset, fmt.Errorf("rr decode: failed to decode record name: %w", err) 491 563 } 492 - r.Name = name 493 564 494 - rtype, offset, err := getU16(buf, offset) 565 + var rtype uint16 566 + rtype, offset, err = getU16(buf, offset) 495 567 if err != nil { 496 - return offset, err 568 + return offset, fmt.Errorf("rr decode: failed to decode RType for %s: %w", r.Name, err) 497 569 } 498 570 r.RType = DNSType(rtype) 499 571 500 - rclass, offset, err := getU16(buf, offset) 572 + var rclass uint16 573 + rclass, offset, err = getU16(buf, offset) 501 574 if err != nil { 502 - return offset, err 575 + return offset, fmt.Errorf("rr decode: failed to decode RClass for %s: %w", r.Name, err) 503 576 } 504 577 r.RClass = DNSClass(rclass) 505 578 506 579 r.TTL, offset, err = getU32(buf, offset) 507 580 if err != nil { 508 - return offset, err 581 + return offset, fmt.Errorf("rr decode: failed to decode TTL for %s: %w", r.Name, err) 509 582 } 510 583 511 584 r.RDLength, offset, err = getU16(buf, offset) 512 585 if err != nil { 513 - return offset, err 586 + return offset, fmt.Errorf("rr decode: failed to decode RDLength for %s: %w", r.Name, err) 514 587 } 515 588 516 589 switch r.RType { ··· 553 626 if r.RData != nil { 554 627 offset, err = r.RData.Decode(buf, offset, int(r.RDLength)) 555 628 if err != nil { 556 - return offset, err 629 + return offset, fmt.Errorf("rr decode: failed to decode RData for %s (%s): %w", r.Name, r.RType.String(), err) 557 630 } 558 631 } 559 632 ··· 561 634 } 562 635 563 636 // Encode encdoes a resource record and returns the input bytes appened. 564 - func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint16) []byte { 565 - bytes = encode_domain(bytes, r.Name, offsets) 637 + func (r *ResourceRecord) Encode(bytes []byte, offsets *map[string]uint16) ([]byte, error) { 638 + var err error 639 + bytes, err = encodeDomain(bytes, r.Name, offsets) 640 + if err != nil { 641 + return nil, fmt.Errorf("rr encode: failed to encode record name %s: %w", r.Name, err) 642 + } 643 + 566 644 bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RType)) 567 645 bytes = binary.BigEndian.AppendUint16(bytes, uint16(r.RClass)) 568 646 bytes = binary.BigEndian.AppendUint32(bytes, r.TTL) 569 647 570 648 rdata_start := len(bytes) 571 649 bytes = binary.BigEndian.AppendUint16(bytes, 0) 572 - bytes = r.RData.Encode(bytes, offsets) 650 + bytes, err = r.RData.Encode(bytes, offsets) 651 + if err != nil { 652 + return nil, fmt.Errorf("rr encode: failed to encode RData for %s (%s): %w", r.Name, r.RType.String(), err) 653 + } 654 + 573 655 rdata_length := uint16(len(bytes) - rdata_start - 2) 574 656 binary.BigEndian.PutUint16(bytes[rdata_start:rdata_start+2], rdata_length) 575 657 576 - return bytes 658 + return bytes, nil 577 659 }
+472 -83
resource_record_test.go
··· 2 2 3 3 import ( 4 4 "encoding/binary" 5 + "errors" 5 6 "net" 6 7 "testing" 7 8 ··· 9 10 "github.com/stretchr/testify/require" 10 11 ) 11 12 12 - func TestTXTRecord(t *testing.T) { 13 - rdataBytes := []byte{0x03, 'a', 'b', 'c', 0x03, 'd', 'e', 'f'} 14 - rdlength := uint16(len(rdataBytes)) 13 + func buildRRBytes(t *testing.T, name string, rtype DNSType, rclass DNSClass, ttl uint32, rdataBytes []byte) []byte { 14 + t.Helper() 15 + buf := []byte{} 16 + offsets := make(map[string]uint16) 15 17 16 - buf := []byte{0x00} 17 - buf = binary.BigEndian.AppendUint16(buf, uint16(TXTType)) 18 - buf = binary.BigEndian.AppendUint16(buf, uint16(IN)) 19 - buf = binary.BigEndian.AppendUint32(buf, 3600) 20 - buf = binary.BigEndian.AppendUint16(buf, rdlength) 18 + encodedName, err := encodeDomain(buf, name, &offsets) 19 + require.NoError(t, err, "Failed to encode name in test helper") 20 + buf = encodedName 21 + 22 + buf = binary.BigEndian.AppendUint16(buf, uint16(rtype)) 23 + buf = binary.BigEndian.AppendUint16(buf, uint16(rclass)) 24 + buf = binary.BigEndian.AppendUint32(buf, ttl) 25 + 26 + buf = binary.BigEndian.AppendUint16(buf, uint16(len(rdataBytes))) 21 27 buf = append(buf, rdataBytes...) 22 28 23 - rr := &ResourceRecord{} 24 - offset, err := rr.Decode(buf, 0) 25 - require.NoError(t, err) 26 - assert.Equal(t, len(buf), offset) 27 - require.IsType(t, &TXT{}, rr.RData) 29 + return buf 30 + } 28 31 29 - txtData := rr.RData.(*TXT) 32 + func encodeRData(t *testing.T, rdata ResourceRecordData) []byte { 33 + t.Helper() 34 + buf := []byte{} 35 + offsets := make(map[string]uint16) 36 + encodedRData, err := rdata.Encode(buf, &offsets) 37 + require.NoError(t, err, "Failed to encode RDATA in test helper") 38 + return encodedRData 39 + } 30 40 31 - expectedDecodedData := []string{"abc", "def"} 32 - assert.Equal(t, expectedDecodedData, txtData.TxtData, "Decoded TXT data does not match expected concatenation") 41 + func TestARecord(t *testing.T) { 42 + addr := net.ParseIP("192.168.1.1").To4() 43 + rdataBytes := []byte(addr) 44 + a := &A{} 45 + 46 + offset, err := a.Decode([]byte{}, 0, 4) 47 + assert.Error(t, err, "Decode should fail with empty buffer") 48 + assert.True(t, errors.Is(err, &BufferOverflowError{})) 49 + 50 + offset, err = a.Decode(rdataBytes, 0, 4) 51 + assert.NoError(t, err) 52 + assert.Equal(t, 4, offset) 53 + assert.Equal(t, addr, a.Address) 54 + 55 + _, err = a.Decode([]byte{1, 2, 3}, 0, 3) 56 + assert.Error(t, err) 57 + assert.Contains(t, err.Error(), "A record:") 58 + 59 + _, err = a.Decode([]byte{1, 2, 3, 4, 5}, 0, 5) 60 + assert.Error(t, err) 61 + assert.Contains(t, err.Error(), "A record:") 62 + 63 + addr = net.ParseIP("192.168.1.1").To4() 64 + err = nil 65 + aEncode := &A{Address: addr} 66 + encoded := encodeRData(t, aEncode) 67 + assert.NoError(t, err) 68 + assert.Equal(t, rdataBytes, encoded) 69 + } 70 + 71 + func TestNSRecord(t *testing.T) { 72 + nsName := "ns1.example.com" 73 + offsets := make(map[string]uint16) 74 + rdataBytes, _ := encodeDomain([]byte{}, nsName, &offsets) 75 + ns := &NS{} 33 76 34 - txtToEncode := &TXT{TxtData: []string{"test"}} 35 - expectedEncodedRdata := []byte{0x04, 't', 'e', 's', 't'} 77 + offset, err := ns.Decode(rdataBytes, 0, len(rdataBytes)) 78 + assert.NoError(t, err) 79 + assert.Equal(t, len(rdataBytes), offset) 80 + assert.Equal(t, nsName, ns.NSDName) 36 81 37 - encodeBuf := []byte{} 38 - encodedRdata := txtToEncode.Encode(encodeBuf, nil) 82 + _, err = ns.Decode(rdataBytes[:len(rdataBytes)-2], 0, len(rdataBytes)-2) 83 + assert.Error(t, err) 84 + assert.True(t, errors.Is(err, &BufferOverflowError{})) 85 + assert.ErrorContains(t, err, "NS record: failed to decode NSDName") 39 86 40 - assert.Equal(t, expectedEncodedRdata, encodedRdata, "Encoded TXT RDATA is incorrect") 87 + nsEncode := &NS{NSDName: nsName} 88 + encoded := encodeRData(t, nsEncode) 89 + assert.Equal(t, rdataBytes, encoded) 41 90 } 42 91 43 - func TestHINFORecordRFCCompliance(t *testing.T) { 44 - rdataBytes := []byte{0x03, 'C', 'P', 'U', 0x02, 'O', 'S'} 45 - rdlength := uint16(len(rdataBytes)) 92 + func TestCNAMERecord(t *testing.T) { 93 + cname := "target.example.com" 94 + offsets := make(map[string]uint16) 95 + rdataBytes, _ := encodeDomain([]byte{}, cname, &offsets) 96 + c := &CNAME{} 97 + 98 + offset, err := c.Decode(rdataBytes, 0, len(rdataBytes)) 99 + assert.NoError(t, err) 100 + assert.Equal(t, len(rdataBytes), offset) 101 + assert.Equal(t, cname, c.CName) 46 102 47 - buf := []byte{0x00} 48 - buf = binary.BigEndian.AppendUint16(buf, uint16(HINFOType)) 49 - buf = binary.BigEndian.AppendUint16(buf, uint16(IN)) 50 - buf = binary.BigEndian.AppendUint32(buf, 3600) 51 - buf = binary.BigEndian.AppendUint16(buf, rdlength) 52 - buf = append(buf, rdataBytes...) 103 + _, err = c.Decode(rdataBytes[:5], 0, 5) 104 + assert.Error(t, err) 105 + assert.ErrorContains(t, err, "CNAME record") 53 106 54 - rr := &ResourceRecord{} 55 - offset, err := rr.Decode(buf, 0) 56 - require.NoError(t, err) 57 - assert.Equal(t, len(buf), offset) 58 - require.IsType(t, &HINFO{}, rr.RData) 107 + cEncode := &CNAME{CName: cname} 108 + encoded := encodeRData(t, cEncode) 109 + assert.Equal(t, rdataBytes, encoded) 110 + } 59 111 60 - hinfoData := rr.RData.(*HINFO) 112 + func TestSOARecord(t *testing.T) { 113 + mname := "ns.example.com" 114 + rname := "admin.example.com" 115 + serial := uint32(2023010101) 116 + refresh := uint32(7200) 117 + retry := uint32(3600) 118 + expire := uint32(1209600) 119 + minimum := uint32(3600) 61 120 62 - assert.Equal(t, "CPU", hinfoData.CPU, "Decoded HINFO CPU does not match") 63 - assert.Equal(t, "OS", hinfoData.OS, "Decoded HINFO OS does not match") 121 + soaEncode := &SOA{MName: mname, RName: rname, Serial: serial, Refresh: refresh, Retry: retry, Expire: expire, Minimum: minimum} 122 + rdataBytes := encodeRData(t, soaEncode) 123 + soa := &SOA{} 64 124 65 - hinfoToEncode := &HINFO{CPU: "Intel", OS: "Linux"} 66 - expectedEncodedRdata := []byte{0x05, 'I', 'n', 't', 'e', 'l', 0x05, 'L', 'i', 'n', 'u', 'x'} 67 - encodeBuf := []byte{} 68 - encodedRdata := hinfoToEncode.Encode(encodeBuf, nil) 125 + offset, err := soa.Decode(rdataBytes, 0, len(rdataBytes)) 126 + assert.NoError(t, err) 127 + assert.Equal(t, len(rdataBytes), offset) 128 + assert.Equal(t, *soaEncode, *soa) 129 + 130 + _, err = soa.Decode(rdataBytes[:len(rdataBytes)-5], 0, len(rdataBytes)-5) 131 + assert.Error(t, err) 132 + assert.ErrorContains(t, err, "SOA record:") 133 + 134 + nameOffset := make(map[string]uint16) 135 + mnameBytes, _ := encodeDomain([]byte{}, mname, &nameOffset) 136 + rnameBytes, _ := encodeDomain([]byte{}, rname, &nameOffset) 137 + shortRdataBytes := append(mnameBytes, rnameBytes...) 138 + _, err = soa.Decode(shortRdataBytes, 0, len(shortRdataBytes)) 139 + assert.Error(t, err) 140 + assert.ErrorContains(t, err, "SOA record") 141 + } 142 + 143 + func TestPTRRecord(t *testing.T) { 144 + ptrName := "host.example.com" 145 + offsets := make(map[string]uint16) 146 + rdataBytes, _ := encodeDomain([]byte{}, ptrName, &offsets) 147 + ptr := &PTR{} 148 + 149 + offset, err := ptr.Decode(rdataBytes, 0, len(rdataBytes)) 150 + assert.NoError(t, err) 151 + assert.Equal(t, len(rdataBytes), offset) 152 + assert.Equal(t, ptrName, ptr.PTRDName) 153 + 154 + _, err = ptr.Decode(rdataBytes[:3], 0, 3) 155 + assert.Error(t, err) 156 + assert.ErrorContains(t, err, "PTR record: failed to decode PTRDName") 157 + 158 + ptrEncode := &PTR{PTRDName: ptrName} 159 + encoded := encodeRData(t, ptrEncode) 160 + assert.Equal(t, rdataBytes, encoded) 161 + } 162 + 163 + func TestMXRecord(t *testing.T) { 164 + preference := uint16(10) 165 + exchange := "mail.example.com" 166 + 167 + mxEncode := &MX{Preference: preference, Exchange: exchange} 168 + rdataBytes := encodeRData(t, mxEncode) 169 + mx := &MX{} 170 + 171 + offset, err := mx.Decode(rdataBytes, 0, len(rdataBytes)) 172 + assert.NoError(t, err) 173 + assert.Equal(t, len(rdataBytes), offset) 174 + assert.Equal(t, *mxEncode, *mx) 175 + 176 + _, err = mx.Decode([]byte{0}, 0, 1) 177 + assert.Error(t, err) 178 + assert.ErrorContains(t, err, "MX record") 179 + 180 + buf := make([]byte, 2) 181 + binary.BigEndian.PutUint16(buf, preference) 182 + buf = append(buf, []byte{4, 'm', 'a'}...) 183 + _, err = mx.Decode(buf, 0, len(buf)) 184 + assert.Error(t, err) 185 + assert.ErrorContains(t, err, "MX record: failed to decode Exchange") 186 + } 187 + 188 + func TestTXTRecord(t *testing.T) { 189 + txtData := []string{"abc", "def"} 190 + txtEncode := &TXT{TxtData: txtData} 191 + rdataBytes := encodeRData(t, txtEncode) 192 + txt := &TXT{} 193 + 194 + offset, err := txt.Decode(rdataBytes, 0, len(rdataBytes)) 195 + require.NoError(t, err, "TXT Decode failed") 196 + assert.Equal(t, len(rdataBytes), offset) 197 + assert.Equal(t, txtData, txt.TxtData, "Decoded TXT data mismatch") 198 + 199 + txtDataEmpty := []string{""} 200 + txtEncodeEmpty := &TXT{TxtData: txtDataEmpty} 201 + rdataBytesEmpty := encodeRData(t, txtEncodeEmpty) 202 + offset, err = txt.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty)) 203 + require.NoError(t, err, "TXT Decode with empty string failed") 204 + assert.Equal(t, len(rdataBytesEmpty), offset) 205 + assert.Equal(t, txtDataEmpty, txt.TxtData) 206 + 207 + txtDataMulti := []string{"v=spf1", "include:_spf.google.com", "~all"} 208 + txtEncodeMulti := &TXT{TxtData: txtDataMulti} 209 + rdataBytesMulti := encodeRData(t, txtEncodeMulti) 210 + offset, err = txt.Decode(rdataBytesMulti, 0, len(rdataBytesMulti)) 211 + require.NoError(t, err, "TXT Decode with multiple strings failed") 212 + assert.Equal(t, len(rdataBytesMulti), offset) 213 + assert.Equal(t, txtDataMulti, txt.TxtData) 214 + 215 + _, err = txt.Decode([]byte{}, 0, 0) 216 + assert.NoError(t, err) 217 + 218 + _, err = txt.Decode([]byte{5, 'd', 'a', 't'}, 0, 4) 219 + assert.Error(t, err) 220 + assert.ErrorContains(t, err, "TXT record: string segment length 5 exceeds RDLENGTH boundary 4") 221 + 222 + encoded := encodeRData(t, txtEncode) 223 + assert.Equal(t, rdataBytes, encoded) 224 + } 225 + 226 + func TestHINFORecord(t *testing.T) { 227 + cpu := "Intel" 228 + os := "Linux" 229 + hinfoEncode := &HINFO{CPU: cpu, OS: os} 230 + rdataBytes := encodeRData(t, hinfoEncode) 231 + hinfo := &HINFO{} 232 + 233 + offset, err := hinfo.Decode(rdataBytes, 0, len(rdataBytes)) 234 + require.NoError(t, err, "HINFO Decode failed") 235 + assert.Equal(t, len(rdataBytes), offset) 236 + assert.Equal(t, cpu, hinfo.CPU) 237 + assert.Equal(t, os, hinfo.OS) 238 + 239 + hinfoEncodeEmpty := &HINFO{CPU: "", OS: ""} 240 + rdataBytesEmpty := encodeRData(t, hinfoEncodeEmpty) 241 + offset, err = hinfo.Decode(rdataBytesEmpty, 0, len(rdataBytesEmpty)) 242 + require.NoError(t, err, "HINFO Decode with empty strings failed") 243 + assert.Equal(t, len(rdataBytesEmpty), offset) 244 + assert.Equal(t, "", hinfo.CPU) 245 + assert.Equal(t, "", hinfo.OS) 246 + 247 + _, err = hinfo.Decode([]byte{}, 0, 0) 248 + assert.Error(t, err) 249 + assert.ErrorContains(t, err, "HINFO record:") 250 + 251 + _, err = hinfo.Decode([]byte{5, 'I', 'n'}, 0, 3) 252 + assert.Error(t, err) 253 + assert.ErrorContains(t, err, "buffer overflow:") 254 + 255 + _, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l'}, 0, 6) 256 + assert.Error(t, err) 257 + assert.ErrorContains(t, err, "HINFO record:") 258 + 259 + _, err = hinfo.Decode([]byte{5, 'I', 'n', 't', 'e', 'l', 5, 'L', 'i'}, 0, 9) 260 + assert.Error(t, err) 261 + assert.ErrorContains(t, err, "buffer overflow:") 262 + 263 + extraData := append(rdataBytes, 0xFF) 264 + _, err = hinfo.Decode(extraData, 0, len(extraData)) 265 + assert.Error(t, err) 266 + assert.ErrorContains(t, err, "HINFO record:") 69 267 70 - assert.Equal(t, expectedEncodedRdata, encodedRdata, "Encoded HINFO RDATA is incorrect") 268 + _, err = hinfo.Decode([]byte{10, 'a', 'b', 'c'}, 0, 4) 269 + assert.Error(t, err) 270 + assert.ErrorContains(t, err, "buffer overflow:") 71 271 } 72 272 73 - func TestWKSRecordDecoding(t *testing.T) { 273 + func TestWKSRecord(t *testing.T) { 74 274 addr := net.ParseIP("192.168.1.1").To4() 75 275 proto := byte(6) 76 276 bitmap := []byte{0x01, 0x80} 77 - rdataBytes := append(addr, proto) 78 - rdataBytes = append(rdataBytes, bitmap...) 79 - rdlength := uint16(len(rdataBytes)) 80 277 81 - buf := []byte{0x00} 82 - buf = binary.BigEndian.AppendUint16(buf, uint16(WKSType)) 83 - buf = binary.BigEndian.AppendUint16(buf, uint16(IN)) 84 - buf = binary.BigEndian.AppendUint32(buf, 3600) 85 - buf = binary.BigEndian.AppendUint16(buf, rdlength) 86 - buf = append(buf, rdataBytes...) 278 + wksEncode := &WKS{Address: addr, Protocol: proto, BitMap: bitmap} 279 + rdataBytes := encodeRData(t, wksEncode) 280 + wks := &WKS{} 87 281 88 - rr := &ResourceRecord{} 89 - offset, err := rr.Decode(buf, 0) 90 - require.NoError(t, err) 91 - assert.Equal(t, len(buf), offset) 92 - require.IsType(t, &WKS{}, rr.RData) 282 + offset, err := wks.Decode(rdataBytes, 0, len(rdataBytes)) 283 + require.NoError(t, err, "WKS Decode failed") 284 + assert.Equal(t, len(rdataBytes), offset) 285 + assert.Equal(t, addr, wks.Address.To4()) 286 + assert.Equal(t, proto, wks.Protocol) 287 + assert.Equal(t, bitmap, wks.BitMap) 93 288 94 - wksData := rr.RData.(*WKS) 289 + wksEncodeNoBitmap := &WKS{Address: addr, Protocol: proto, BitMap: []byte{}} 290 + rdataBytesNoBitmap := encodeRData(t, wksEncodeNoBitmap) 291 + wks = &WKS{} 292 + offset, err = wks.Decode(rdataBytesNoBitmap, 0, len(rdataBytesNoBitmap)) 293 + require.NoError(t, err, "WKS Decode without bitmap failed") 294 + assert.Equal(t, len(rdataBytesNoBitmap), offset) 295 + assert.Equal(t, addr, wks.Address.To4()) 296 + assert.Equal(t, proto, wks.Protocol) 297 + assert.Empty(t, wks.BitMap) 95 298 96 - assert.Equal(t, addr, wksData.Address.To4()) 97 - assert.Equal(t, proto, wksData.Protocol) 98 - assert.Equal(t, bitmap, wksData.BitMap) 299 + _, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 4) 300 + assert.Error(t, err) 301 + assert.ErrorContains(t, err, "WKS record: RDLENGTH 4 is too short") 302 + 303 + _, err = wks.Decode([]byte{1, 2, 3}, 0, 5) 304 + assert.Error(t, err) 305 + assert.ErrorContains(t, err, "WKS record: failed to read address") 306 + 307 + _, err = wks.Decode([]byte{1, 2, 3, 4}, 0, 5) 308 + assert.Error(t, err) 309 + assert.ErrorContains(t, err, "WKS record: failed to read protocol") 310 + 311 + _, err = wks.Decode([]byte{1, 2, 3, 4, 6, 0x01}, 0, 7) 312 + assert.Error(t, err) 313 + assert.ErrorContains(t, err, "WKS record: failed to read bitmap") 99 314 } 100 315 101 - func TestSOADecodeWithCompression(t *testing.T) { 102 - input := []byte{0x69, 0x7b, 0x81, 0x83, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0xf, 0x6e, 0x6f, 0x77, 0x61, 0x79, 0x74, 0x68, 0x69, 0x73, 0x65, 0x78, 0x69, 0x73, 0x74, 0x73, 0x3, 0x63, 0x6f, 0x6d, 0x0, 0x0, 0x1, 0x0, 0x1, 0xc0, 0x1c, 0x0, 0x6, 0x0, 0x1, 0x0, 0x0, 0x3, 0x84, 0x0, 0x3d, 0x1, 0x61, 0xc, 0x67, 0x74, 0x6c, 0x64, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x3, 0x6e, 0x65, 0x74, 0x0, 0x5, 0x6e, 0x73, 0x74, 0x6c, 0x64, 0xc, 0x76, 0x65, 0x72, 0x69, 0x73, 0x69, 0x67, 0x6e, 0x2d, 0x67, 0x72, 0x73, 0xc0, 0x1c, 0x67, 0xaa, 0xc5, 0x6b, 0x0, 0x0, 0x7, 0x8, 0x0, 0x0, 0x3, 0x84, 0x0, 0x9, 0x3a, 0x80, 0x0, 0x0, 0x3, 0x84} 316 + func TestReservedRecord(t *testing.T) { 317 + rdataBytes := []byte{0xDE, 0xAD, 0xBE, 0xEF} 318 + r := &Reserved{} 103 319 104 - msg := &Message{} 105 - err := msg.Decode(input) 320 + offset, err := r.Decode(rdataBytes, 0, len(rdataBytes)) 106 321 assert.NoError(t, err) 322 + assert.Equal(t, len(rdataBytes), offset) 323 + assert.Equal(t, rdataBytes, r.Bytes) 107 324 108 - assert.Equal(t, 1, len(msg.Authority)) 325 + _, err = r.Decode(rdataBytes[:2], 0, 4) 326 + assert.Error(t, err) 327 + assert.ErrorContains(t, err, "reserved record: failed to read data") 328 + 329 + rEncode := &Reserved{Bytes: rdataBytes} 330 + encoded := encodeRData(t, rEncode) 331 + assert.Equal(t, rdataBytes, encoded) 332 + 333 + rEncodeNil := &Reserved{Bytes: nil} 334 + encodedNil := encodeRData(t, rEncodeNil) 335 + assert.Empty(t, encodedNil) 336 + } 337 + 338 + func TestResourceRecordDecode(t *testing.T) { 339 + tests := []struct { 340 + name string 341 + input []byte 342 + expectedRR *ResourceRecord 343 + wantErr bool 344 + wantErrType error 345 + wantErrMsg string 346 + }{ 347 + { 348 + name: "Valid A record", 349 + input: buildRRBytes(t, "a.com", AType, IN, 60, []byte{1, 1, 1, 1}), 350 + expectedRR: &ResourceRecord{ 351 + Name: "a.com", RType: AType, RClass: IN, TTL: 60, RDLength: 4, RData: &A{net.IP{1, 1, 1, 1}}, 352 + }, 353 + }, 354 + { 355 + name: "Valid TXT record", 356 + input: buildRRBytes(t, "b.org", TXTType, IN, 300, encodeRData(t, &TXT{[]string{"hello", "world"}})), 357 + expectedRR: &ResourceRecord{ 358 + Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RDLength: 12, RData: &TXT{[]string{"hello", "world"}}, 359 + }, 360 + }, 361 + { 362 + name: "Unknown record type", 363 + input: buildRRBytes(t, "c.net", DNSType(9999), IN, 10, []byte{0xca, 0xfe}), 364 + expectedRR: &ResourceRecord{ 365 + Name: "c.net", RType: DNSType(9999), RClass: IN, TTL: 10, RDLength: 2, RData: &Reserved{[]byte{0xca, 0xfe}}, 366 + }, 367 + }, 368 + { 369 + name: "Truncated name", 370 + input: []byte{3, 'a', 'b'}, 371 + wantErr: true, 372 + wantErrType: &BufferOverflowError{}, 373 + wantErrMsg: "rr decode:", 374 + }, 375 + { 376 + name: "Truncated type", 377 + input: buildRRBytes(t, "d.com", AType, IN, 60, []byte{1})[:5], 378 + wantErr: true, 379 + wantErrType: &BufferOverflowError{}, 380 + wantErrMsg: "rr decode:", 381 + }, 382 + { 383 + name: "Truncated RDATA section", 384 + input: buildRRBytes(t, "e.com", AType, IN, 60, []byte{1, 2, 3, 4})[:15], 385 + wantErr: true, 386 + wantErrType: &BufferOverflowError{}, 387 + wantErrMsg: "rr decode:", 388 + }, 389 + { 390 + name: "RDLENGTH mismatch (claims longer than buffer)", 391 + input: func() []byte { 392 + buf := buildRRBytes(t, "f.com", AType, IN, 60, []byte{1, 2, 3, 4}) 393 + binary.BigEndian.PutUint16(buf[10:12], 10) 394 + return buf[:14] 395 + }(), 396 + wantErr: true, 397 + wantErrType: &BufferOverflowError{}, 398 + wantErrMsg: "rr decode:", 399 + }, 400 + { 401 + name: "RDLENGTH mismatch (RData decoder consumes less)", 402 + input: func() []byte { 403 + rdataBytes := encodeRData(t, &TXT{[]string{"short"}}) 404 + buf := buildRRBytes(t, "g.com", TXTType, IN, 60, rdataBytes) 405 + nameLen := len(buf) - 10 - len(rdataBytes) 406 + rdlenPos := nameLen + 8 407 + binary.BigEndian.PutUint16(buf[rdlenPos:rdlenPos+2], uint16(len(rdataBytes)+5)) 408 + return buf 409 + }(), 410 + wantErr: true, 411 + wantErrMsg: "rr decode:", 412 + }, 413 + } 109 414 110 - rr := msg.Authority[0] 111 - assert.Equal(t, DNSType(6), rr.RType) 112 - assert.Equal(t, DNSClass(1), rr.RClass) 113 - assert.Equal(t, uint32(900), rr.TTL) 114 - assert.Equal(t, uint16(61), rr.RDLength) 415 + for _, tt := range tests { 416 + t.Run(tt.name, func(t *testing.T) { 417 + rr := &ResourceRecord{} 418 + offset, err := rr.Decode(tt.input, 0) 115 419 116 - soa, ok := msg.Authority[0].RData.(*SOA) 117 - assert.True(t, ok) 420 + if tt.wantErr { 421 + assert.Error(t, err) 422 + if tt.wantErrType != nil { 423 + assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err) 424 + } 425 + if tt.wantErrMsg != "" { 426 + assert.ErrorContains(t, err, tt.wantErrMsg) 427 + } 428 + } else { 429 + assert.NoError(t, err) 430 + assert.Equal(t, len(tt.input), offset, "Offset should match input length") 431 + assert.Equal(t, tt.expectedRR.Name, rr.Name) 432 + assert.Equal(t, tt.expectedRR.RType, rr.RType) 433 + assert.Equal(t, tt.expectedRR.RClass, rr.RClass) 434 + assert.Equal(t, tt.expectedRR.TTL, rr.TTL) 435 + assert.Equal(t, tt.expectedRR.RDLength, rr.RDLength) 436 + assert.Equal(t, tt.expectedRR.RData, rr.RData) 437 + } 438 + }) 439 + } 440 + } 118 441 119 - assert.Equal(t, "a.gtld-servers.net", soa.MName) 120 - assert.Equal(t, "nstld.verisign-grs.com", soa.RName) 121 - assert.Equal(t, uint32(1739244907), soa.Serial) 122 - assert.Equal(t, uint32(1800), soa.Refresh) 123 - assert.Equal(t, uint32(900), soa.Retry) 124 - assert.Equal(t, uint32(604800), soa.Expire) 125 - assert.Equal(t, uint32(900), soa.Minimum) 442 + func TestResourceRecordEncode(t *testing.T) { 443 + tests := []struct { 444 + name string 445 + rr *ResourceRecord 446 + expectedLen int 447 + wantErr bool 448 + wantErrType error 449 + wantErrMsg string 450 + }{ 451 + { 452 + name: "Valid A record", 453 + rr: &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}}, 454 + }, 455 + { 456 + name: "Valid TXT record", 457 + rr: &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{"hello", "world"}}}, 458 + }, 459 + { 460 + name: "Encode fail - Invalid Name", 461 + rr: &ResourceRecord{Name: "a..b", RType: AType, RClass: IN, TTL: 60, RData: &A{net.IP{1, 1, 1, 1}}}, 462 + wantErr: true, 463 + wantErrType: &InvalidLabelError{}, 464 + wantErrMsg: "rr encode: failed to encode record name a..b", 465 + }, 466 + { 467 + name: "Encode fail - Invalid RData (A record)", 468 + rr: &ResourceRecord{Name: "a.com", RType: AType, RClass: IN, TTL: 60, RData: &A{net.ParseIP("::1")}}, 469 + wantErr: true, 470 + wantErrMsg: "rr encode: failed to encode RData for a.com (A): A record: cannot encode non-IPv4 address", 471 + }, 472 + { 473 + name: "Encode fail - Invalid RData (TXT record)", 474 + rr: &ResourceRecord{Name: "b.org", RType: TXTType, RClass: IN, TTL: 300, RData: &TXT{[]string{string(make([]byte, 256))}}}, 475 + wantErr: true, 476 + wantErrMsg: "rr encode: failed to encode RData for b.org (TXT): TXT record: string segment length 256 exceeds maximum 255", 477 + }, 478 + } 126 479 127 - encoded := msg.Encode() 128 - assert.Equal(t, input, encoded) 480 + for _, tt := range tests { 481 + t.Run(tt.name, func(t *testing.T) { 482 + offsets := make(map[string]uint16) 483 + encodedBytes, err := tt.rr.Encode([]byte{}, &offsets) 484 + 485 + if tt.wantErr { 486 + assert.Error(t, err) 487 + if tt.wantErrType != nil { 488 + assert.True(t, errors.Is(err, tt.wantErrType), "Error type mismatch. Got %T", err) 489 + } 490 + if tt.wantErrMsg != "" { 491 + assert.ErrorContains(t, err, tt.wantErrMsg) 492 + } 493 + } else { 494 + assert.NoError(t, err) 495 + assert.NotEmpty(t, encodedBytes) 496 + 497 + decodedRR := &ResourceRecord{} 498 + offset, decodeErr := decodedRR.Decode(encodedBytes, 0) 499 + assert.NoError(t, decodeErr, "Failed to decode back encoded RR") 500 + if decodeErr == nil { 501 + assert.Equal(t, len(encodedBytes), offset, "Decoded offset mismatch") 502 + assert.Equal(t, tt.rr.Name, decodedRR.Name) 503 + assert.Equal(t, tt.rr.RType, decodedRR.RType) 504 + assert.Equal(t, tt.rr.RClass, decodedRR.RClass) 505 + assert.Equal(t, tt.rr.TTL, decodedRR.TTL) 506 + if tt.rr.RData == nil { 507 + assert.IsType(t, &Reserved{}, decodedRR.RData, "Nil RData should decode as Reserved") 508 + assert.Empty(t, decodedRR.RData.(*Reserved).Bytes, "Nil RData should decode as empty Reserved") 509 + assert.Equal(t, uint16(0), decodedRR.RDLength, "Nil RData should have RDLength 0") 510 + } else { 511 + assert.Equal(t, tt.rr.RData, decodedRR.RData, "RData mismatch after round trip") 512 + assert.NotEqual(t, uint16(0), decodedRR.RDLength, "Non-nil RData should have non-zero RDLength") 513 + } 514 + } 515 + } 516 + }) 517 + } 129 518 }
+1 -1
types.go
··· 213 213 // *map[string]uint8 - map containing labels and offsets for domain name compression 214 214 type ResourceRecordData interface { 215 215 Decode([]byte, int, int) (int, error) 216 - Encode([]byte, *map[string]uint16) []byte 216 + Encode([]byte, *map[string]uint16) ([]byte, error) 217 217 String() string 218 218 } 219 219
+16 -16
utils.go
··· 6 6 7 7 // getU8 returns the first byte from a byte array at offset. 8 8 func getU8(buf []byte, offset int) (uint8, int, error) { 9 - next_offset := offset + 1 10 - if next_offset > len(buf) { 11 - return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset} 9 + nextOffset := offset + 1 10 + if nextOffset > len(buf) { 11 + return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset} 12 12 } 13 13 14 - return buf[offset], next_offset, nil 14 + return buf[offset], nextOffset, nil 15 15 } 16 16 17 17 // getU16 returns the bigEndian uint16 from a byte array at offset. 18 18 func getU16(buf []byte, offset int) (uint16, int, error) { 19 - next_offset := offset + 2 20 - if next_offset > len(buf) { 21 - return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset} 19 + nextOffset := offset + 2 20 + if nextOffset > len(buf) { 21 + return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset} 22 22 } 23 23 24 - return binary.BigEndian.Uint16(buf[offset:]), next_offset, nil 24 + return binary.BigEndian.Uint16(buf[offset:]), nextOffset, nil 25 25 } 26 26 27 27 // getU32 returns the bigEndian uint32 from a byte array at offset. 28 28 func getU32(buf []byte, offset int) (uint32, int, error) { 29 - next_offset := offset + 4 30 - if next_offset > len(buf) { 31 - return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset} 29 + nextOffset := offset + 4 30 + if nextOffset > len(buf) { 31 + return 0, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset} 32 32 } 33 33 34 - return binary.BigEndian.Uint32(buf[offset:]), next_offset, nil 34 + return binary.BigEndian.Uint32(buf[offset:]), nextOffset, nil 35 35 } 36 36 37 37 // getSlice returns a slice of bytes from a byte array at an offset and of length. 38 38 func getSlice(buf []byte, offset int, length int) ([]byte, int, error) { 39 - next_offset := offset + length 40 - if next_offset > len(buf) { 41 - return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: next_offset} 39 + nextOffset := offset + length 40 + if nextOffset > len(buf) { 41 + return nil, len(buf), &BufferOverflowError{Length: len(buf), Offset: nextOffset} 42 42 } 43 43 44 - return buf[offset:next_offset], next_offset, nil 44 + return buf[offset:nextOffset], nextOffset, nil 45 45 }