From e0239c1698196e25361c1c1c3a58029a02507f86 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Wed, 30 Jul 2025 15:59:45 +0300 Subject: [PATCH 01/17] added another way to configure arpspoofer, added some funcs to network --- arpspoof/arpspoof.go | 43 ++++++++++++++++++++++++++++++++++++++++++- cmd/marpspoof/main.go | 2 +- cmd/mshark/cli.go | 2 +- network/network.go | 19 +++++++++++++++++-- version.go | 2 +- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/arpspoof/arpspoof.go b/arpspoof/arpspoof.go index 7f53b8e..3a9de15 100644 --- a/arpspoof/arpspoof.go +++ b/arpspoof/arpspoof.go @@ -23,7 +23,6 @@ import ( ) const ( - protocolARP = 0x0806 unixEthPAll = 0x03 ) @@ -32,6 +31,9 @@ var ( probeTargetsInterval = 60 * time.Second refreshARPTableInterval = 15 * time.Second arpSpoofTargetsInterval = 1 * time.Second + errARPSpoofConfig = fmt.Errorf( + `failed parsing arp options. Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true"`, + ) ) type Packet struct { @@ -48,6 +50,45 @@ type ARPSpoofConfig struct { Debug bool } +// NewARPSpoofConfig creates ARPSpoofConfig from a list of options separated by semicolon and logger. +// +// Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true;interface eth0;gateway 192.168.1.1"`. +// All fields in configuration string are optional. +func NewARPSpoofConfig(s string, logger *zerolog.Logger) (*ARPSpoofConfig, error) { + asc := &ARPSpoofConfig{Logger: logger} + for opt := range strings.SplitSeq(strings.ToLower(s), ";") { + keyval := strings.SplitN(strings.Trim(opt, " "), " ", 2) + if len(keyval) < 2 { + return nil, errARPSpoofConfig + } + key := keyval[0] + val := keyval[1] + switch key { + case "targets": + asc.Targets = val + case "interface": + asc.Interface = val + case "gateway": + gateway, err := netip.ParseAddr(val) + if err != nil { + return nil, err + } + asc.Gateway = &gateway + case "fullduplex": + if val == "true" { + asc.FullDuplex = true + } + case "debug": + if val == "true" { + asc.Debug = true + } + default: + return nil, errARPSpoofConfig + } + } + return asc, nil +} + type ARPTable struct { sync.RWMutex Ifname string diff --git a/cmd/marpspoof/main.go b/cmd/marpspoof/main.go index f0ed072..9624247 100644 --- a/cmd/marpspoof/main.go +++ b/cmd/marpspoof/main.go @@ -38,7 +38,7 @@ func root(args []string) error { flags.BoolVar(&conf.Debug, "d", false, "Enable debug logging") nocolor := flags.Bool("nocolor", false, "Disable colored output") flags.BoolFunc("I", "Display list of interfaces and exit.", func(flagValue string) error { - if err := network.DisplayInterfaces(); err != nil { + if err := network.DisplayInterfaces(false); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", app, err) os.Exit(2) } diff --git a/cmd/mshark/cli.go b/cmd/mshark/cli.go index ad989d0..09585e9 100644 --- a/cmd/mshark/cli.go +++ b/cmd/mshark/cli.go @@ -89,7 +89,7 @@ func root(args []string) error { packetBuffer := flags.Int("b", 8192, "The maximum size of packet queue.") flags.StringVar(&conf.Expr, "e", "", `BPF filter expression. Example: "ip proto tcp".`) flags.BoolFunc("D", "Display list of interfaces and exit.", func(flagValue string) error { - if err := network.DisplayInterfaces(); err != nil { + if err := network.DisplayInterfaces(true); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", app, err) os.Exit(2) } diff --git a/network/network.go b/network/network.go index a9a1bf7..091bef6 100644 --- a/network/network.go +++ b/network/network.go @@ -42,7 +42,7 @@ func InterfaceByName(name string) (*net.Interface, error) { return in, nil } -func DisplayInterfaces() error { +func DisplayInterfaces(includeAny bool) error { w := new(tabwriter.Writer) w.Init(os.Stdout, 0, 0, 2, ' ', tabwriter.TabIndent) ifaces, err := net.Interfaces() @@ -50,7 +50,9 @@ func DisplayInterfaces() error { return fmt.Errorf("failed to get network interfaces: %v", err) } fmt.Fprintln(w, "Index\tName\tFlags") - fmt.Fprintln(w, "0\tany\tUP") + if includeAny { + fmt.Fprintln(w, "0\tany\tUP") + } for _, iface := range ifaces { fmt.Fprintf(w, "%d\t%s\t%s\n", iface.Index, iface.Name, strings.ToUpper(iface.Flags.String())) } @@ -146,3 +148,16 @@ func GetIPv4PrefixFromInterface(iface *net.Interface) (netip.Prefix, error) { } return netip.Prefix{}, fmt.Errorf("no IPv4 prefix found") } + +func IsLocalAddress(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + ip := net.ParseIP(host) + if ip != nil { + return ip.IsLoopback() + } + host = strings.ToLower(host) + return strings.HasSuffix(host, ".local") || host == "localhost" +} diff --git a/version.go b/version.go index 1757501..3ca5177 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.9" +const Version string = "mshark v0.0.10" From cb79f9669f4a89892ee2a05b582acf327de72a56 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Mon, 4 Aug 2025 07:35:22 +0300 Subject: [PATCH 02/17] parsing now creates new buffer for each packet --- layers/arp.go | 20 +++++++++++--------- layers/dns.go | 21 +++++++++++---------- layers/ethernet.go | 10 ++++++---- layers/ftp.go | 4 +++- layers/http.go | 10 ++++++---- layers/icmp.go | 10 ++++++---- layers/icmpv6.go | 10 ++++++---- layers/ipv4.go | 28 +++++++++++++++------------- layers/ipv6.go | 16 +++++++++------- layers/snmp.go | 4 +++- layers/ssh.go | 26 ++++++++++++++------------ layers/tcp.go | 22 ++++++++++++---------- layers/tls.go | 24 +++++++++++++----------- layers/udp.go | 12 +++++++----- version.go | 2 +- 15 files changed, 123 insertions(+), 96 deletions(-) diff --git a/layers/arp.go b/layers/arp.go index ec4e789..0532662 100644 --- a/layers/arp.go +++ b/layers/arp.go @@ -148,30 +148,32 @@ func (ap *ARPPacket) UnmarshalBinary(data []byte) error { if len(data) < headerSizeARP { return fmt.Errorf("minimum header size for ARP is %d bytes, got %d bytes", headerSizeARP, len(data)) } - ap.HardwareType = binary.BigEndian.Uint16(data[0:2]) - ap.ProtocolType = binary.BigEndian.Uint16(data[2:4]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + ap.HardwareType = binary.BigEndian.Uint16(buf[0:2]) + ap.ProtocolType = binary.BigEndian.Uint16(buf[2:4]) ap.ProtocolTypeDesc = ptypedesc(ap.ProtocolType) if ap.ProtocolTypeDesc == "Unknown" { return fmt.Errorf("unknown protocol type") } - ap.Hlen = data[4] - ap.Plen = data[5] - op := Operation(binary.BigEndian.Uint16(data[6:8])) + ap.Hlen = buf[4] + ap.Plen = buf[5] + op := Operation(binary.BigEndian.Uint16(buf[6:8])) opdesc := opdesc(op) if opdesc == "Unknown" { return fmt.Errorf("unknown operation") } ap.Op = &ARPOperation{Val: op, Desc: opdesc} hoffset := 8 + ap.Hlen - ap.SenderMAC = net.HardwareAddr(data[8:hoffset]) + ap.SenderMAC = net.HardwareAddr(buf[8:hoffset]) poffset := hoffset + ap.Plen var ok bool - ap.SenderIP, ok = netip.AddrFromSlice(data[hoffset:poffset]) + ap.SenderIP, ok = netip.AddrFromSlice(buf[hoffset:poffset]) if !ok { return fmt.Errorf("failed parsing sender IP address") } - ap.TargetMAC = net.HardwareAddr(data[poffset : poffset+ap.Hlen]) - ap.TargetIP, ok = netip.AddrFromSlice(data[poffset+ap.Hlen : poffset+ap.Hlen+ap.Plen]) + ap.TargetMAC = net.HardwareAddr(buf[poffset : poffset+ap.Hlen]) + ap.TargetIP, ok = netip.AddrFromSlice(buf[poffset+ap.Hlen : poffset+ap.Hlen+ap.Plen]) if !ok { return fmt.Errorf("failed parsing target IP address") } diff --git a/layers/dns.go b/layers/dns.go index 2980f5a..3d9f9f0 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -241,14 +241,16 @@ func (d *DNSMessage) Parse(data []byte) error { if len(data) < headerSizeDNS { return fmt.Errorf("minimum header size for DNS is %d bytes, got %d bytes", headerSizeDNS, len(data)) } - d.TransactionID = binary.BigEndian.Uint16(data[0:2]) - d.Flags = newDNSFlags(binary.BigEndian.Uint16(data[2:4])) - d.QDCount = binary.BigEndian.Uint16(data[4:6]) - d.ANCount = binary.BigEndian.Uint16(data[6:8]) - d.NSCount = binary.BigEndian.Uint16(data[8:10]) - d.ARCount = binary.BigEndian.Uint16(data[10:headerSizeDNS]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + d.TransactionID = binary.BigEndian.Uint16(buf[0:2]) + d.Flags = newDNSFlags(binary.BigEndian.Uint16(buf[2:4])) + d.QDCount = binary.BigEndian.Uint16(buf[4:6]) + d.ANCount = binary.BigEndian.Uint16(buf[6:8]) + d.NSCount = binary.BigEndian.Uint16(buf[8:10]) + d.ARCount = binary.BigEndian.Uint16(buf[10:headerSizeDNS]) var tail []byte - payload := data[headerSizeDNS:] + payload := buf[headerSizeDNS:] d.Questions = nil d.AnswerRRs = nil d.AuthorityRRs = nil @@ -459,7 +461,7 @@ type RDataSOA struct { } func (d *RDataSOA) String() string { - return fmt.Sprintf(`Primary name server: %s + return fmt.Sprintf(`Primary name server: %s - Responsible authority's mailbox: %s - Serial number: %d - Refresh interval: %d @@ -498,7 +500,6 @@ type RDataAAAA struct { func (d *RDataAAAA) String() string { return fmt.Sprintf("Address: %s", d.Address) - } type RDataOPT struct { @@ -510,7 +511,7 @@ type RDataOPT struct { } func (d *RDataOPT) String() string { - return fmt.Sprintf(`UDP payload size: %d + return fmt.Sprintf(`UDP payload size: %d - Higher bits in extended RCODE: %#02x - EDNS0 version: %d - Z: %d diff --git a/layers/ethernet.go b/layers/ethernet.go index da7b412..cff433a 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -92,12 +92,14 @@ func (ef *EthernetFrame) UnmarshalBinary(data []byte) error { if len(data) < headerSizeEthernet { return fmt.Errorf("did not read a complete Ethernet frame, only %d bytes read", len(data)) } - ef.DstMAC = net.HardwareAddr(data[0:6]) - ef.SrcMAC = net.HardwareAddr(data[6:12]) - et := EtherType(binary.BigEndian.Uint16(data[12:14])) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + ef.DstMAC = net.HardwareAddr(buf[0:6]) + ef.SrcMAC = net.HardwareAddr(buf[6:12]) + et := EtherType(binary.BigEndian.Uint16(buf[12:14])) etdesc := ethertypedesc(et) ef.EtherType = &EthernetType{Val: et, Desc: etdesc} - ef.Payload = data[headerSizeEthernet:] + ef.Payload = buf[headerSizeEthernet:] ef.DstVendor = oui.VendorWithMAC(ef.DstMAC) ef.SrcVendor = oui.VendorWithMAC(ef.SrcMAC) return nil diff --git a/layers/ftp.go b/layers/ftp.go index f16b288..368c34c 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -21,9 +21,11 @@ func (f *FTPMessage) Summary() string { } func (f *FTPMessage) Parse(data []byte) error { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) f.summary = nil f.data = nil - sp := bytes.Split(data, crlf) + sp := bytes.Split(buf, crlf) lsp := len(sp) switch { case lsp > 2: diff --git a/layers/http.go b/layers/http.go index 732c3da..2e67c31 100644 --- a/layers/http.go +++ b/layers/http.go @@ -59,13 +59,15 @@ func (h *HTTPMessage) Summary() string { } func (h *HTTPMessage) Parse(data []byte) error { - if !bytes.Contains(data, protohttp10) && !bytes.Contains(data, protohttp11) { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + if !bytes.Contains(buf, protohttp10) && !bytes.Contains(buf, protohttp11) { h.Request = nil h.Response = nil return nil } - reader := bufio.NewReader(bytes.NewReader(data)) - if bytes.HasPrefix(data, protohttp11) || bytes.HasPrefix(data, protohttp10) { + reader := bufio.NewReader(bytes.NewReader(buf)) + if bytes.HasPrefix(buf, protohttp11) || bytes.HasPrefix(buf, protohttp10) { resp, err := http.ReadResponse(reader, nil) if err != nil { return err @@ -73,7 +75,7 @@ func (h *HTTPMessage) Parse(data []byte) error { h.Response = resp h.Request = nil } else { - reader := bufio.NewReader(bytes.NewReader(data)) + reader := bufio.NewReader(bytes.NewReader(buf)) req, err := http.ReadRequest(reader) if err != nil { return err diff --git a/layers/icmp.go b/layers/icmp.go index c0351c8..40231bc 100644 --- a/layers/icmp.go +++ b/layers/icmp.go @@ -46,10 +46,12 @@ func (i *ICMPSegment) Parse(data []byte) error { if len(data) < headerSizeICMP { return fmt.Errorf("minimum header size for ICMP is %d bytes, got %d bytes", headerSizeICMP, len(data)) } - i.Type = data[0] - i.Code = data[1] - i.Checksum = binary.BigEndian.Uint16(data[2:4]) - i.Data = data[headerSizeICMP:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + i.Type = buf[0] + i.Code = buf[1] + i.Checksum = binary.BigEndian.Uint16(buf[2:4]) + i.Data = buf[headerSizeICMP:] var pLen int switch i.Type { case 0, 3, 5, 8, 11: diff --git a/layers/icmpv6.go b/layers/icmpv6.go index fde107d..c5b4a3a 100644 --- a/layers/icmpv6.go +++ b/layers/icmpv6.go @@ -44,10 +44,12 @@ func (i *ICMPv6Segment) Parse(data []byte) error { if len(data) < headerSizeICMPv6 { return fmt.Errorf("minimum header size for ICMPv6 is %d bytes, got %d bytes", headerSizeICMPv6, len(data)) } - i.Type = data[0] - i.Code = data[1] - i.Checksum = binary.BigEndian.Uint16(data[2:headerSizeICMPv6]) - i.Data = data[headerSizeICMPv6:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + i.Type = buf[0] + i.Code = buf[1] + i.Checksum = binary.BigEndian.Uint16(buf[2:headerSizeICMPv6]) + i.Data = buf[headerSizeICMPv6:] var pLen int switch i.Type { case 1, 2, 3, 4, 128, 129, 133: diff --git a/layers/ipv4.go b/layers/ipv4.go index 5aba8c4..bada8f0 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -155,38 +155,40 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { if len(data) < headerSizeIPv4 { return fmt.Errorf("minimum header size for IPv4 is %d bytes, got %d bytes", headerSizeIPv4, len(data)) } - versionIHL := data[0] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + versionIHL := buf[0] p.Version = versionIHL >> 4 p.IHL = versionIHL & 15 - dscpECN := data[1] + dscpECN := buf[1] p.DSCP = dscpECN >> 2 p.DSCPDesc = dscpdesc(p.DSCP) p.ECN = dscpECN & 3 - p.TotalLength = binary.BigEndian.Uint16(data[2:4]) - p.Identification = binary.BigEndian.Uint16(data[4:6]) - flagsOffset := binary.BigEndian.Uint16(data[6:8]) + p.TotalLength = binary.BigEndian.Uint16(buf[2:4]) + p.Identification = binary.BigEndian.Uint16(buf[4:6]) + flagsOffset := binary.BigEndian.Uint16(buf[6:8]) flags := uint8(flagsOffset >> 13) p.Flags = NewIPv4Flags(flags) p.FragmentOffset = flagsOffset & (1<<13 - 1) - p.TTL = data[8] - proto := IPProto(data[9]) + p.TTL = buf[8] + proto := IPProto(buf[9]) p.Protocol = &IPv4Proto{Val: proto, Desc: protodesc(proto)} - p.HeaderChecksum = binary.BigEndian.Uint16(data[headerChecksumOffsetIPv4:12]) + p.HeaderChecksum = binary.BigEndian.Uint16(buf[headerChecksumOffsetIPv4:12]) var ok bool - p.SrcIP, ok = netip.AddrFromSlice(data[12:16]) + p.SrcIP, ok = netip.AddrFromSlice(buf[12:16]) if !ok { return fmt.Errorf("malformed IPv4 address") } - p.DstIP, ok = netip.AddrFromSlice(data[16:headerSizeIPv4]) + p.DstIP, ok = netip.AddrFromSlice(buf[16:headerSizeIPv4]) if !ok { return fmt.Errorf("malformed IPv4 address") } if p.IHL > 5 { offset := headerSizeIPv4 + ((p.IHL - 5) << 2) - p.Options = data[headerSizeIPv4:offset] - p.Payload = data[offset:] + p.Options = buf[headerSizeIPv4:offset] + p.Payload = buf[offset:] } else { - p.Payload = data[headerSizeIPv4:] + p.Payload = buf[headerSizeIPv4:] } return nil } diff --git a/layers/ipv6.go b/layers/ipv6.go index 8efaef2..9a61cc8 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -80,17 +80,19 @@ func (p *IPv6Packet) Parse(data []byte) error { if len(data) < headerSizeIPv6 { return fmt.Errorf("minimum header size for IPv6 is %d bytes, got %d bytes", headerSizeIPv6, len(data)) } - versionTrafficFlow := binary.BigEndian.Uint32(data[0:4]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + versionTrafficFlow := binary.BigEndian.Uint32(buf[0:4]) p.Version = uint8(versionTrafficFlow >> 28) p.TrafficClass = newTrafficiClass(uint8((versionTrafficFlow >> 20) & 0xFF)) p.FlowLabel = versionTrafficFlow & (1<<20 - 1) - p.PayloadLength = binary.BigEndian.Uint16(data[4:6]) - p.NextHeader = data[6] + p.PayloadLength = binary.BigEndian.Uint16(buf[4:6]) + p.NextHeader = buf[6] p.NextHeaderDesc = p.nextHeader() - p.HopLimit = data[7] - p.SrcIP, _ = netip.AddrFromSlice(data[8:24]) - p.DstIP, _ = netip.AddrFromSlice(data[24:headerSizeIPv6]) - p.payload = data[headerSizeIPv6:] + p.HopLimit = buf[7] + p.SrcIP, _ = netip.AddrFromSlice(buf[8:24]) + p.DstIP, _ = netip.AddrFromSlice(buf[24:headerSizeIPv6]) + p.payload = buf[headerSizeIPv6:] return nil } diff --git a/layers/snmp.go b/layers/snmp.go index 984de12..815d1b5 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -18,7 +18,9 @@ func (s *SNMPMessage) Summary() string { } func (s *SNMPMessage) Parse(data []byte) error { - s.Payload = data + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + s.Payload = buf return nil } diff --git a/layers/ssh.go b/layers/ssh.go index e128906..7afd74f 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -85,34 +85,36 @@ func (s *SSHMessage) Parse(data []byte) error { if len(data) < messageSizeSSH { return fmt.Errorf("minimum message size for SSH is %d bytes, got %d bytes", messageSizeSSH, len(data)) } + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) s.Protocol = "" s.Messages = nil - if bytes.HasSuffix(data, crlf) { - s.Protocol = bytesToStr(bytes.TrimSuffix(data, crlf)) + if bytes.HasSuffix(buf, crlf) { + s.Protocol = bytesToStr(bytes.TrimSuffix(buf, crlf)) return nil } s.Messages = make([]*Message, 0, 3) - for len(data) > 0 { + for len(buf) > 0 { m := &Message{} s.Messages = append(s.Messages, m) - plen := binary.BigEndian.Uint32(data[0:4]) + plen := binary.BigEndian.Uint32(buf[0:4]) if plen > 0xffff { - m.Payload = data + m.Payload = buf break } - m.MesssageType = data[5] + m.MesssageType = buf[5] if m.MesssageTypeDesc = mtypedesc(m.MesssageType); m.MesssageTypeDesc == "Unknown" { - m.Payload = data + m.Payload = buf break } m.PacketLength = plen - m.PaddingLength = data[4] + m.PaddingLength = buf[4] offset := int(messageSizeSSH + m.PacketLength - 2) - if offset <= len(data) { - m.Payload = data[messageSizeSSH:offset] - data = data[offset:] + if offset <= len(buf) { + m.Payload = buf[messageSizeSSH:offset] + buf = buf[offset:] } else { - m.Payload = data[messageSizeSSH:] + m.Payload = buf[messageSizeSSH:] break } } diff --git a/layers/tcp.go b/layers/tcp.go index 56658cc..5d70833 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -114,19 +114,21 @@ func (t *TCPSegment) Parse(data []byte) error { if len(data) < headerSizeTCP { return fmt.Errorf("minimum header size for TCP is %d bytes, got %d bytes", headerSizeTCP, len(data)) } - t.SrcPort = binary.BigEndian.Uint16(data[0:2]) - t.DstPort = binary.BigEndian.Uint16(data[2:4]) - t.SeqNumber = binary.BigEndian.Uint32(data[4:8]) - t.AckNumber = binary.BigEndian.Uint32(data[8:12]) - offsetReservedFlags := binary.BigEndian.Uint16(data[12:14]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + t.SrcPort = binary.BigEndian.Uint16(buf[0:2]) + t.DstPort = binary.BigEndian.Uint16(buf[2:4]) + t.SeqNumber = binary.BigEndian.Uint32(buf[4:8]) + t.AckNumber = binary.BigEndian.Uint32(buf[8:12]) + offsetReservedFlags := binary.BigEndian.Uint16(buf[12:14]) t.DataOffset = uint8(offsetReservedFlags >> 12) t.Reserved = uint8((offsetReservedFlags >> 8) & 15) t.Flags = newTCPFlags(uint8(offsetReservedFlags & (1<<8 - 1))) - t.WindowSize = binary.BigEndian.Uint16(data[14:16]) - t.Checksum = binary.BigEndian.Uint16(data[16:18]) - t.UrgentPointer = binary.BigEndian.Uint16(data[18:headerSizeTCP]) - t.Options = data[headerSizeTCP : t.DataOffset<<2] - t.payload = data[t.DataOffset<<2:] + t.WindowSize = binary.BigEndian.Uint16(buf[14:16]) + t.Checksum = binary.BigEndian.Uint16(buf[16:18]) + t.UrgentPointer = binary.BigEndian.Uint16(buf[18:headerSizeTCP]) + t.Options = buf[headerSizeTCP : t.DataOffset<<2] + t.payload = buf[t.DataOffset<<2:] return nil } diff --git a/layers/tls.go b/layers/tls.go index fff4e1f..ce46dc2 100644 --- a/layers/tls.go +++ b/layers/tls.go @@ -539,29 +539,31 @@ func (t *TLSMessage) printRecords() string { } func (t *TLSMessage) Parse(data []byte) error { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) t.Records = make([]*Record, 0, 5) - if len(data) < headerSizeTLS { + if len(buf) < headerSizeTLS { return ErrTLSTooShort } - for len(data) > 0 { - ctype := data[0] + for len(buf) > 0 { + ctype := buf[0] ctdesc := ctdesc(ctype) if ctdesc == "Unknown" { break } - if len(data) < 3 { + if len(buf) < 3 { break } - ver := binary.BigEndian.Uint16(data[1:3]) + ver := binary.BigEndian.Uint16(buf[1:3]) verdesc := verdesc(ver) if verdesc == "Unknown" { break } - if len(data) < headerSizeTLS { + if len(buf) < headerSizeTLS { break } - rlen := binary.BigEndian.Uint16(data[3:headerSizeTLS]) - rb := min(uint16(headerSizeTLS+rlen), uint16(len(data))) + rlen := binary.BigEndian.Uint16(buf[3:headerSizeTLS]) + rb := min(uint16(headerSizeTLS+rlen), uint16(len(buf))) if rb < headerSizeTLS { break } @@ -570,12 +572,12 @@ func (t *TLSMessage) Parse(data []byte) error { ContentTypeDesc: ctdesc, Version: &TLSVersion{Val: ver, Desc: verdesc}, Length: rlen, - Data: data[headerSizeTLS:rb], + Data: buf[headerSizeTLS:rb], } t.Records = append(t.Records, r) - data = data[rb:] + buf = buf[rb:] } - t.Data = data + t.Data = buf return nil } diff --git a/layers/udp.go b/layers/udp.go index 96eff98..ca6565c 100644 --- a/layers/udp.go +++ b/layers/udp.go @@ -71,11 +71,13 @@ func (u *UDPSegment) UnmarshalBinary(data []byte) error { if len(data) < headerSizeUDP { return fmt.Errorf("minimum header size for UDP is %d bytes, got %d bytes", headerSizeUDP, len(data)) } - u.SrcPort = binary.BigEndian.Uint16(data[0:2]) - u.DstPort = binary.BigEndian.Uint16(data[2:4]) - u.UDPLength = binary.BigEndian.Uint16(data[4:6]) - u.Checksum = binary.BigEndian.Uint16(data[6:headerSizeUDP]) - u.Payload = data[headerSizeUDP:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + u.SrcPort = binary.BigEndian.Uint16(buf[0:2]) + u.DstPort = binary.BigEndian.Uint16(buf[2:4]) + u.UDPLength = binary.BigEndian.Uint16(buf[4:6]) + u.Checksum = binary.BigEndian.Uint16(buf[6:headerSizeUDP]) + u.Payload = buf[headerSizeUDP:] return nil } diff --git a/version.go b/version.go index 3ca5177..98fbc19 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.10" +const Version string = "mshark v0.0.11" From 72b3621d46adddc80f350cfb50ae44b30127675e Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Mon, 4 Aug 2025 09:51:45 +0300 Subject: [PATCH 03/17] fixed race condition when arpspoof shutdwon happened before start finished --- arpspoof/arpspoof.go | 33 ++++++++++++++++++++------------- version.go | 2 +- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/arpspoof/arpspoof.go b/arpspoof/arpspoof.go index 3a9de15..8c14c46 100644 --- a/arpspoof/arpspoof.go +++ b/arpspoof/arpspoof.go @@ -12,6 +12,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/malfunkt/iprange" @@ -156,19 +157,20 @@ func (at *ARPTable) Refresh() error { } type ARPSpoofer struct { - targets []netip.Addr - gwIP netip.Addr - gwMAC net.HardwareAddr - iface *net.Interface - hostIP netip.Addr - hostMAC net.HardwareAddr - fullduplex bool - arpTable *ARPTable - packets chan *Packet - logger *zerolog.Logger - quit chan bool - wg sync.WaitGroup - p *packet.Conn + targets []netip.Addr + gwIP netip.Addr + gwMAC net.HardwareAddr + iface *net.Interface + hostIP netip.Addr + hostMAC net.HardwareAddr + fullduplex bool + startingFlag atomic.Bool + arpTable *ARPTable + packets chan *Packet + logger *zerolog.Logger + quit chan bool + wg sync.WaitGroup + p *packet.Conn } func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { @@ -302,6 +304,7 @@ func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { } func (ar *ARPSpoofer) Start() { + ar.startingFlag.Store(true) ar.logger.Info().Msg("[arp spoofer] Started") go ar.handlePackets() ar.logger.Debug().Msgf("[arp spoofer] Probing %d targets", len(ar.targets)) @@ -312,6 +315,7 @@ func (ar *ARPSpoofer) Start() { go ar.probeTargets() go ar.refreshARPTable() ar.wg.Add(1) + ar.startingFlag.Store(false) for { select { case <-ar.quit: @@ -325,6 +329,9 @@ func (ar *ARPSpoofer) Start() { } func (ar *ARPSpoofer) Stop() error { + for ar.startingFlag.Load() { + time.Sleep(50 * time.Millisecond) + } var err error ar.logger.Info().Msg("[arp spoofer] Stopping...") close(ar.quit) diff --git a/version.go b/version.go index 98fbc19..2b1cb60 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.11" +const Version string = "mshark v0.0.12" From 127edd95a672718c05030fb4f8f6f4a75d630c4a Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 11:27:03 +0300 Subject: [PATCH 04/17] added HTTPS parsing for DNS, fixed data race, updated network package --- layers/dns.go | 337 ++++++++++++++++++++++++++++++++++++++------- layers/layers.go | 57 ++++++-- mshark.go | 10 +- network/network.go | 32 +++++ version.go | 2 +- 5 files changed, 376 insertions(+), 62 deletions(-) diff --git a/layers/dns.go b/layers/dns.go index 3d9f9f0..74e90a5 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -2,6 +2,7 @@ package layers import ( "encoding/binary" + "encoding/hex" "fmt" "net/netip" "strings" @@ -206,7 +207,7 @@ func (d *DNSMessage) String() string { func (d *DNSMessage) Summary() string { var sb strings.Builder - sb.WriteString(fmt.Sprintf("DNS Message: %s %s %#04x ", d.Flags.OPCodeDesc, d.Flags.QRDesc, d.TransactionID)) + sb.WriteString(fmt.Sprintf("DNS Message: %s (%s) %#04x ", d.Flags.OPCodeDesc, d.Flags.QRDesc, d.TransactionID)) for _, rec := range d.Questions { sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) if sb.Len() > maxLenSummary { @@ -214,19 +215,19 @@ func (d *DNSMessage) Summary() string { } } for _, rec := range d.AnswerRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } } for _, rec := range d.AuthorityRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } } for _, rec := range d.AdditionalRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } @@ -236,8 +237,7 @@ result: return sb.String()[:maxLenSummary] + string(ellipsis) } -// Parse parses the given byte data into a DNSMessage struct. -func (d *DNSMessage) Parse(data []byte) error { +func (d *DNSMessage) UnmarshalBinary(data []byte) error { if len(data) < headerSizeDNS { return fmt.Errorf("minimum header size for DNS is %d bytes, got %d bytes", headerSizeDNS, len(data)) } @@ -250,26 +250,40 @@ func (d *DNSMessage) Parse(data []byte) error { d.NSCount = binary.BigEndian.Uint16(buf[8:10]) d.ARCount = binary.BigEndian.Uint16(buf[10:headerSizeDNS]) var tail []byte + var err error payload := buf[headerSizeDNS:] - d.Questions = nil - d.AnswerRRs = nil - d.AuthorityRRs = nil - d.AdditionalRRs = nil if d.QDCount > 0 { - d.Questions, tail = parseQueries(payload, payload, d.QDCount) + d.Questions, tail, err = parseQueries(payload, payload, d.QDCount) + if err != nil { + return fmt.Errorf("failed parsing queries: %v", err) + } } if d.ANCount > 0 { - d.AnswerRRs, tail = parseResourceRecords(payload, tail, d.ANCount) + d.AnswerRRs, tail, err = parseResourceRecords(payload, tail, d.ANCount) + if err != nil { + return fmt.Errorf("failed parsing answers: %v", err) + } } if d.NSCount > 0 { - d.AuthorityRRs, tail = parseResourceRecords(payload, tail, d.NSCount) + d.AuthorityRRs, tail, err = parseResourceRecords(payload, tail, d.NSCount) + if err != nil { + return fmt.Errorf("failed parsing authority records: %v", err) + } } if d.ARCount > 0 { - d.AdditionalRRs, _ = parseResourceRecords(payload, tail, d.ARCount) + d.AdditionalRRs, _, err = parseResourceRecords(payload, tail, d.ARCount) + if err != nil { + return fmt.Errorf("failed parsing additional records: %v", err) + } } return nil } +// Parse parses the given byte data into a DNSMessage struct. +func (d *DNSMessage) Parse(data []byte) error { + return d.UnmarshalBinary(data) +} + func (d *DNSMessage) NextLayer() (layer string, payload []byte) { return } func (d *DNSMessage) printRecords() string { @@ -298,7 +312,7 @@ func (d *DNSMessage) printRecords() string { sb.WriteString(rec.String()) } } - return sb.String() + return strings.TrimSuffix(sb.String(), "\n") } type RecordClass struct { @@ -412,6 +426,28 @@ func (rt *ResourceRecord) String() string { return record } +func (rt *ResourceRecord) Summary() string { + var summary string + switch rd := rt.RData.(type) { + case *RDataA: + case *RDataAAAA: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.Address) + case *RDataNS: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.NsdName) + case *RDataCNAME: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.CName) + case *RDataSOA: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.PrimaryNS) + case *RDataMX: + summary = fmt.Sprintf("%s %d %s ", rt.Type.Name, rd.Preference, rd.Exchange) + case *RDataTXT: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.TxtData) + default: + summary = fmt.Sprintf("%s ", rt.Type.Name) + } + return summary +} + type QueryEntry struct { Name string // Name of the node to which this record pertains. Type *RecordType // Type of RR in numeric form. @@ -524,12 +560,105 @@ func (d *RDataOPT) String() string { d.DataLen) } +type SvcParamKey struct { + Val uint16 + Desc string +} + +// https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml +func svcparamkeydesc(key uint16) string { + var svcdesc string + switch key { + case 0: + svcdesc = "mandatory" + case 1: + svcdesc = "alpn" + case 2: + svcdesc = "no-default-alpn" + case 3: + svcdesc = "port" + case 4: + svcdesc = "ipv4hint" + case 5: + svcdesc = "ech" + case 6: + svcdesc = "ipv6hint" + case 7: + svcdesc = "dohpath" + case 8: + svcdesc = "ohttp" + case 9: + svcdesc = "tls-supported-groups" + } + return svcdesc +} + +func newSvcParamKey(key uint16) *SvcParamKey { + return &SvcParamKey{Val: key, Desc: svcparamkeydesc(key)} +} + +func (spk *SvcParamKey) String() string { + return fmt.Sprintf("%s (%d)", spk.Desc, spk.Val) +} + +type SvcParam struct { + Key *SvcParamKey + Length uint16 + Value []byte // TODO: add proper parsing +} + +func newSvcParam(data []byte) (*SvcParam, []byte, error) { + if len(data) < 4 { + return nil, nil, ErrSliceBounds + } + key := newSvcParamKey(binary.BigEndian.Uint16(data[0:2])) + length := binary.BigEndian.Uint16(data[2:4]) + offset := 4 + length + if offset > uint16(len(data)) { + return nil, nil, ErrSliceBounds + } + value := data[4:offset] + return &SvcParam{Key: key, Length: length, Value: value}, data[offset:], nil +} + +func (sp *SvcParam) String() string { + return fmt.Sprintf(` - SvcParamKey: %s + - SvcParamValue length: %d + - SvcParamValue: %s +`, + sp.Key, + sp.Length, + hex.EncodeToString(sp.Value), + ) +} + type RDataHTTPS struct { - Data string // TODO: add proper parsing + SvcPriority uint16 + Length int + TargetName string + SvcParams []*SvcParam +} + +func (d *RDataHTTPS) printSvcParams() string { + var sb strings.Builder + for _, p := range d.SvcParams { + if p == nil { + continue + } + sb.WriteString(p.String()) + } + return strings.TrimRight(sb.String(), "\n") } func (d *RDataHTTPS) String() string { - return d.Data + return fmt.Sprintf(`SvcPriority: %d + - TargetName: %s + - SvcParams: +%s`, + d.SvcPriority, + d.TargetName, + d.printSvcParams(), + ) } type RDataUnknown struct { @@ -543,15 +672,24 @@ func (d *RDataUnknown) String() string { // extractDomain extracts the DNS domain name from the given payload and tail. // // The domain name is parsed according to RFC 1035 section 4.1. -func extractDomain(payload, tail []byte) (string, []byte) { +func extractDomain(payload, tail []byte) (string, []byte, error) { // see https://brunoscheufler.com/blog/2024-05-12-building-a-dns-message-parser#domain-names var domainName string for { blen := tail[0] if blen>>6 == 0b11 { + if len(tail) < 2 { + return "", nil, ErrSliceBounds + } // compressed message offset is 14 bits according to RFC 1035 section 4.1.4 offset := binary.BigEndian.Uint16(tail[0:2])&(1<<14-1) - headerSizeDNS - part, _ := extractDomain(payload, payload[offset:]) // TODO: iterative approach + if offset > uint16(len(payload)) { + return "", nil, ErrSliceBounds + } + part, _, err := extractDomain(payload, payload[offset:]) // TODO: iterative approach + if err != nil { + return "", nil, err + } domainName += part tail = tail[2:] break @@ -560,17 +698,27 @@ func extractDomain(payload, tail []byte) (string, []byte) { if blen == 0 { break } + if int(blen) > len(tail) { + return "", nil, ErrSliceBounds + } domainName += bytesToStr(tail[0:blen]) domainName += "." tail = tail[blen:] } - return strings.TrimRight(domainName, "."), tail + return strings.TrimRight(domainName, "."), tail, nil } -func parseQuery(payload, tail []byte) (*QueryEntry, []byte) { +func parseQuery(payload, tail []byte) (*QueryEntry, []byte, error) { var domain string - domain, tail = extractDomain(payload, tail) + var err error + domain, tail, err = extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } + if len(tail) < 4 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) cls := binary.BigEndian.Uint16(tail[2:4]) tail = tail[4:] @@ -578,43 +726,69 @@ func parseQuery(payload, tail []byte) (*QueryEntry, []byte) { Name: domain, Type: newRecordType(typ), Class: newRecordClass(cls), - }, tail + }, tail, nil } -func parseQueries(payload, tail []byte, numRecords uint16) ([]*QueryEntry, []byte) { +func parseQueries(payload, tail []byte, numRecords uint16) ([]*QueryEntry, []byte, error) { queries := make([]*QueryEntry, numRecords) + var err error for i := range queries { - queries[i], tail = parseQuery(payload, tail) + queries[i], tail, err = parseQuery(payload, tail) + if err != nil { + return nil, nil, err + } } - return queries, tail + return queries, tail, nil } // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4 -func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte) { +func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte, error) { var rdata fmt.Stringer + if rdl > len(tail) { + return nil, nil, ErrSliceBounds + } switch typ { case 1: - addr, _ := netip.AddrFromSlice(tail[0:rdl]) + addr, ok := netip.AddrFromSlice(tail[0:rdl]) + if !ok { + return nil, nil, ErrParsingAddress + } rdata = &RDataA{Address: addr} case 2: - domain, _ := extractDomain(payload, tail) + domain, _, err := extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } rdata = &RDataNS{NsdName: domain} case 5: - domain, _ := extractDomain(payload, tail) + domain, _, err := extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } rdata = &RDataCNAME{CName: domain} case 6: var ( primary string mailbox string + err error ) ttail := tail - primary, ttail = extractDomain(payload, ttail) - mailbox, ttail = extractDomain(payload, ttail) + primary, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + mailbox, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + if len(ttail) < 20 { + return nil, nil, ErrSliceBounds + } serial := binary.BigEndian.Uint32(ttail[0:4]) refresh := binary.BigEndian.Uint32(ttail[4:8]) retry := binary.BigEndian.Uint32(ttail[8:12]) expire := binary.BigEndian.Uint32(ttail[12:16]) - min := binary.BigEndian.Uint32(ttail[16:20]) + minttl := binary.BigEndian.Uint32(ttail[16:20]) rdata = &RDataSOA{ PrimaryNS: primary, RespAuthorityMailbox: mailbox, @@ -622,11 +796,14 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte RefreshInterval: refresh, RetryInterval: retry, ExpireLimit: expire, - MinimumTTL: min, + MinimumTTL: minttl, } case 15: preference := binary.BigEndian.Uint16(tail[0:2]) - domain, _ := extractDomain(payload, tail[2:rdl]) + domain, _, err := extractDomain(payload, tail[2:rdl]) + if err != nil { + return nil, nil, err + } rdata = &RDataMX{ Preference: preference, Exchange: domain, @@ -634,9 +811,15 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte case 16: rdata = &RDataTXT{TxtData: string(tail[:rdl])} case 28: - addr, _ := netip.AddrFromSlice(tail[0:rdl]) + addr, ok := netip.AddrFromSlice(tail[0:rdl]) + if !ok { + return nil, nil, ErrParsingAddress + } rdata = &RDataAAAA{Address: addr} case 41: + if len(tail) < 8 { + return nil, nil, ErrSliceBounds + } ups := binary.BigEndian.Uint16(tail[0:2]) hb := tail[2] ednsv := tail[3] @@ -650,35 +833,81 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte DataLen: uint16(rdl), } case 65: - rdata = &RDataHTTPS{Data: string(tail[:rdl])} + priority := binary.BigEndian.Uint16(tail[0:2]) + nameLength := tail[2] + var target string + var err error + ttail := tail[:rdl] + if nameLength == 0 { + target = "Root" + ttail = ttail[3:] + } else { + target, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + } + svcParams := make([]*SvcParam, 10) + var svcParam *SvcParam + for len(ttail) > 0 { + svcParam, ttail, err = newSvcParam(ttail) + if err != nil { + return nil, nil, err + } + if svcParam.Key.Desc == "Unknown" { + continue + } + svcParams[svcParam.Key.Val] = svcParam + } + rdata = &RDataHTTPS{SvcPriority: priority, Length: int(nameLength), TargetName: target, SvcParams: svcParams} default: rdata = &RDataUnknown{Data: string(tail[:rdl])} } - return rdata, tail[rdl:] + if rdl > len(tail) { + return nil, nil, ErrSliceBounds + } + return rdata, tail[rdl:], nil } -func parseRoot(payload, tail []byte) (*ResourceRecord, []byte) { +func parseRoot(payload, tail []byte) (*ResourceRecord, []byte, error) { + if len(tail) < 10 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) rdl := int(binary.BigEndian.Uint16(tail[8:10])) var rdata fmt.Stringer - rdata, tail = parseRData(payload, tail[2:], typ, rdl) + var err error + rdata, tail, err = parseRData(payload, tail[2:], typ, rdl) + if err != nil { + return nil, nil, err + } return &ResourceRecord{ Name: "Root", Type: newRecordType(typ), Class: &RecordClass{}, RData: rdata, - }, tail + }, tail, nil } -func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte) { +func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte, error) { var domain string - domain, tail = extractDomain(payload, tail) + var err error + domain, tail, err = extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } + if len(tail) < 10 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) cls := binary.BigEndian.Uint16(tail[2:4]) ttl := binary.BigEndian.Uint32(tail[4:8]) rdl := binary.BigEndian.Uint16(tail[8:10]) var rdata fmt.Stringer - rdata, tail = parseRData(payload, tail[10:], typ, int(rdl)) + rdata, tail, err = parseRData(payload, tail[10:], typ, int(rdl)) + if err != nil { + return nil, nil, err + } return &ResourceRecord{ Name: domain, Type: newRecordType(typ), @@ -686,17 +915,27 @@ func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte) { TTL: ttl, RDLength: rdl, RData: rdata, - }, tail + }, tail, nil } -func parseResourceRecords(payload, tail []byte, numRecords uint16) ([]*ResourceRecord, []byte) { +func parseResourceRecords(payload, tail []byte, numRecords uint16) ([]*ResourceRecord, []byte, error) { + if len(tail) < 1 { + return nil, nil, ErrSliceBounds + } records := make([]*ResourceRecord, numRecords) + var err error for i := range records { if tail[0] != 0 { - records[i], tail = parseResourceRecord(payload, tail) + records[i], tail, err = parseResourceRecord(payload, tail) + if err != nil { + return nil, nil, err + } } else { - records[i], tail = parseRoot(payload, tail[1:]) + records[i], tail, err = parseRoot(payload, tail[1:]) + if err != nil { + return nil, nil, err + } } } - return records, tail + return records, tail, nil } diff --git a/layers/layers.go b/layers/layers.go index 46c526a..aacf702 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -6,7 +6,7 @@ import ( "unsafe" ) -const maxLenSummary = 100 +const maxLenSummary = 110 var LayerMap = map[string]Layer{ "ETH": &EthernetFrame{}, @@ -26,15 +26,17 @@ var LayerMap = map[string]Layer{ } var ( - bspace = []byte(" ") - dash = []byte("- ") - lfd = []byte("\n- ") - slfd = "\n- " - lf = []byte("\n") - crlf = []byte("\r\n") - dcrlf = []byte("\r\n\r\n") - ellipsis = []byte("...") - contdata = []byte("Continuation data") + bspace = []byte(" ") + dash = []byte("- ") + lfd = []byte("\n- ") + slfd = "\n- " + lf = []byte("\n") + crlf = []byte("\r\n") + dcrlf = []byte("\r\n\r\n") + ellipsis = []byte("...") + contdata = []byte("Continuation data") + ErrParsingAddress = fmt.Errorf("failed parsing IP address") + ErrSliceBounds = fmt.Errorf("slice bounds out of range") ) type Layer interface { @@ -44,6 +46,41 @@ type Layer interface { Summary() string } +func GetNextLayer(layer string) Layer { + switch layer { + case "ETH": + return &EthernetFrame{} + case "IPv4": + return &IPv4Packet{} + case "IPv6": + return &IPv6Packet{} + case "ARP": + return &ARPPacket{} + case "TCP": + return &TCPSegment{} + case "UDP": + return &UDPSegment{} + case "ICMP": + return &ICMPSegment{} + case "ICMPv6": + return &ICMPv6Segment{} + case "DNS": + return &DNSMessage{} + case "FTP": + return &FTPMessage{} + case "HTTP": + return &HTTPMessage{} + case "SNMP": + return &SNMPMessage{} + case "SSH": + return &SSHMessage{} + case "TLS": + return &TLSMessage{} + default: + return nil + } +} + func bytesToStr(b []byte) string { return unsafe.String(unsafe.SliceData(b), len(b)) } diff --git a/mshark.go b/mshark.go index 67b9eee..f905ee9 100644 --- a/mshark.go +++ b/mshark.go @@ -82,7 +82,10 @@ func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { mw.packets++ fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05-0700")) fmt.Fprintln(mw.w, "==================================================================") - next := layers.LayerMap["ETH"] + next := layers.GetNextLayer("ETH") + if next == nil { + return nil + } if err := next.Parse(data); err != nil { return err } @@ -93,7 +96,10 @@ func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { if name == "" || data == nil || len(data) == 0 { return nil } - next = layers.LayerMap[name] + next = layers.GetNextLayer(name) + if next == nil { + return nil + } if err := next.Parse(data); err != nil { return err } diff --git a/network/network.go b/network/network.go index 091bef6..6213112 100644 --- a/network/network.go +++ b/network/network.go @@ -161,3 +161,35 @@ func IsLocalAddress(addr string) bool { host = strings.ToLower(host) return strings.HasSuffix(host, ".local") || host == "localhost" } + +// AddrEqual compares two address strings and returns true if they are equal. +// +// It treats loopback and unspecified IPs as equivalent. Returns false in case of inequality or error. +func AddrEqual(a, b string) bool { + if a == "" || b == "" { + return false + } + addr1, err := netip.ParseAddrPort(a) + if err != nil { + return false + } + addr2, err := netip.ParseAddrPort(b) + if err != nil { + return false + } + if addr1.Addr().IsLoopback() { + if addr1.Addr().Is4In6() || addr1.Addr().Is4() { + addr1 = netip.AddrPortFrom(netip.IPv4Unspecified(), addr1.Port()) + } else { + addr1 = netip.AddrPortFrom(netip.IPv6Unspecified(), addr1.Port()) + } + } + if addr2.Addr().IsLoopback() { + if addr2.Addr().Is4In6() || addr2.Addr().Is4() { + addr2 = netip.AddrPortFrom(netip.IPv4Unspecified(), addr2.Port()) + } else { + addr2 = netip.AddrPortFrom(netip.IPv6Unspecified(), addr2.Port()) + } + } + return addr1.Compare(addr2) == 0 +} diff --git a/version.go b/version.go index 2b1cb60..8db832f 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.12" +const Version string = "mshark v0.0.13" From 8c00df195094eb17a24a4b77142358aafbffb0c6 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 12:02:47 +0300 Subject: [PATCH 05/17] cosmetic changes to mshark --- layers/dns.go | 2 ++ layers/layers.go | 1 + mshark.go | 8 +++++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/layers/dns.go b/layers/dns.go index 74e90a5..cd0e35a 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -589,6 +589,8 @@ func svcparamkeydesc(key uint16) string { svcdesc = "ohttp" case 9: svcdesc = "tls-supported-groups" + default: + svcdesc = "Unknown" } return svcdesc } diff --git a/layers/layers.go b/layers/layers.go index aacf702..392db66 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -47,6 +47,7 @@ type Layer interface { } func GetNextLayer(layer string) Layer { + // TODO (shadowy-pycoder): add this to NextLayer, choose by ports, parse, use fallback on error switch layer { case "ETH": return &EthernetFrame{} diff --git a/mshark.go b/mshark.go index f905ee9..1d1f413 100644 --- a/mshark.go +++ b/mshark.go @@ -8,6 +8,7 @@ import ( "net" "os" "os/signal" + "strings" "time" "github.com/mdlayher/packet" @@ -18,13 +19,14 @@ import ( const unixEthPAll int = 0x03 -var colorMap = map[int]string{ +var colorMap = map[int]string{ // TODO (shadowy-pycoder): add colors from shadowy-pycoder/colors 0: "\033[37m", 1: "\033[36m", 2: "\033[32m", 3: "\033[33m", 4: "\033[35m", } +var packetDelimeter = "\033[37m" + strings.Repeat("─", 66) + "\033[0m" var _ PacketWriter = &Writer{} @@ -80,8 +82,8 @@ func (mw *Writer) printPacket(layer layers.Layer, layerNum int) { // Timestamps are to be generated by the calling code. func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { mw.packets++ - fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05-0700")) - fmt.Fprintln(mw.w, "==================================================================") + fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05.000000-0700")) + fmt.Fprintln(mw.w, packetDelimeter) next := layers.GetNextLayer("ETH") if next == nil { return nil From 0ba2de9e9be02b908c956a27f67f220f8262ba5e Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 19:27:52 +0300 Subject: [PATCH 06/17] added ParseNextLayer function --- layers/dns.go | 5 ++ layers/layers.go | 156 ++++++++++++++++++++++++++++++++++++++++++----- layers/ssh.go | 4 +- 3 files changed, 148 insertions(+), 17 deletions(-) diff --git a/layers/dns.go b/layers/dns.go index cd0e35a..419d3f6 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -8,6 +8,8 @@ import ( "strings" ) +// TODO (shadowy-pycoder): add MarshalJSON + const headerSizeDNS = 12 type DNSFlags struct { @@ -835,6 +837,9 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte DataLen: uint16(rdl), } case 65: + if len(tail) < 3 { + return nil, nil, ErrSliceBounds + } priority := binary.BigEndian.Uint16(tail[0:2]) nameLength := tail[2] var target string diff --git a/layers/layers.go b/layers/layers.go index 392db66..2122ffe 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -2,27 +2,30 @@ package layers import ( + "bytes" + "encoding/binary" "fmt" + "net/netip" "unsafe" ) const maxLenSummary = 110 -var LayerMap = map[string]Layer{ - "ETH": &EthernetFrame{}, - "IPv4": &IPv4Packet{}, - "IPv6": &IPv6Packet{}, - "ARP": &ARPPacket{}, - "TCP": &TCPSegment{}, - "UDP": &UDPSegment{}, - "ICMP": &ICMPSegment{}, - "ICMPv6": &ICMPv6Segment{}, - "DNS": &DNSMessage{}, - "FTP": &FTPMessage{}, - "HTTP": &HTTPMessage{}, - "SNMP": &SNMPMessage{}, - "SSH": &SSHMessage{}, - "TLS": &TLSMessage{}, +var Layers = []string{ + "ETH", + "IPv4", + "IPv6", + "ARP", + "TCP", + "UDP", + "ICMP", + "ICMPv6", + "DNS", + "FTP", + "HTTP", + "SNMP", + "SSH", + "TLS", } var ( @@ -46,6 +49,129 @@ type Layer interface { Summary() string } +func parseNextLayerFallback(data []byte) Layer { + for _, layer := range Layers { + next := GetNextLayer(layer) + if err := next.Parse(data); err == nil { + return next + } + } + return nil +} + +func parseNextLayerFromBytes(data []byte) Layer { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + var next Layer + firstByte := buf[0] + if firstByte >= 0x45 && firstByte <= 0x4F { + next = GetNextLayer("IPv4") + if err := next.Parse(buf); err == nil { + return next + } + } + if firstByte>>4 == 6 { + next = GetNextLayer("IPv6") + if err := next.Parse(buf); err == nil { + return next + } + } + if firstByte == HandshakeTLSVal { + next = GetNextLayer("TLS") + if err := next.Parse(buf); err == nil { + return next + } + } + if len(buf) > 3 { + b1 := binary.BigEndian.Uint16(buf[0:2]) + b2 := binary.BigEndian.Uint16(buf[2:4]) + if b1 == 1 && (b2 == 0x0800 || b2 == 0x86dd) { + next = GetNextLayer("ARP") + if err := next.Parse(buf); err == nil { + return next + } + } + } + if len(buf) > 15 { + b1 := binary.BigEndian.Uint16(buf[12:14]) + if b1 == 0x0806 || b1 == 0x0800 || b1 == 0x86dd { + next = GetNextLayer("ETH") + if err := next.Parse(buf); err == nil { + return next + } + } + } + if bytes.Contains(buf, protohttp10) || bytes.Contains(buf, protohttp11) { + next = GetNextLayer("HTTP") + if err := next.Parse(buf); err == nil { + return next + } + } + if bytes.Contains(buf, protoSSH) { + next = GetNextLayer("SSH") + if err := next.Parse(buf); err == nil { + return next + } + } + return nil +} + +func addrMatch(src, dst *netip.AddrPort, ports []uint16) bool { + var srcPort, dstPort uint16 + if src != nil { + srcPort = src.Port() + } + if dst != nil { + dstPort = src.Port() + } + for _, port := range ports { + if srcPort == port || dstPort == port { + return true + } + } + return false +} + +func parseNextLayerFromAddress(data []byte, src, dst *netip.AddrPort) Layer { + var next Layer + switch { + case addrMatch(src, dst, []uint16{53, 5353, 853, 5355}): + next = GetNextLayer("DNS") + case addrMatch(src, dst, []uint16{80, 8080, 8000, 8888, 81, 591, 5911}): + next = GetNextLayer("HTTP") + case addrMatch(src, dst, []uint16{161, 162, 10161, 10162, 1161, 2161}): + next = GetNextLayer("SNMP") + case addrMatch(src, dst, []uint16{21, 20, 2121, 8021}): + next = GetNextLayer("FTP") + case addrMatch(src, dst, []uint16{22, 2222, 2200, 222, 2022}): + next = GetNextLayer("SSH") + case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444}): + next = GetNextLayer("TLS") + default: + return nil + } + if err := next.Parse(data); err == nil { + return next + } else { + return nil + } +} + +func ParseNextLayer(data []byte, src, dst *netip.AddrPort) Layer { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + var next Layer + if src != nil || dst != nil { + if next = parseNextLayerFromAddress(buf, src, dst); next != nil { + return next + } + } + if next = parseNextLayerFromBytes(buf); next != nil { + return next + } + return parseNextLayerFallback(buf) +} + func GetNextLayer(layer string) Layer { // TODO (shadowy-pycoder): add this to NextLayer, choose by ports, parse, use fallback on error switch layer { diff --git a/layers/ssh.go b/layers/ssh.go index 7afd74f..742d46c 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -9,6 +9,8 @@ import ( const messageSizeSSH = 6 +var protoSSH = []byte("SSH-") + type Message struct { PacketLength uint32 PaddingLength uint8 @@ -87,8 +89,6 @@ func (s *SSHMessage) Parse(data []byte) error { } buf := make([]byte, 0, len(data)) buf = append(buf, data...) - s.Protocol = "" - s.Messages = nil if bytes.HasSuffix(buf, crlf) { s.Protocol = bytesToStr(bytes.TrimSuffix(buf, crlf)) return nil From 49bfa618c7f92477d8e379d2eea519881e238a64 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 20:24:29 +0300 Subject: [PATCH 07/17] changed NextLayer method --- layers/arp.go | 3 ++- layers/dns.go | 4 +++- layers/ethernet.go | 12 +++++++++--- layers/ftp.go | 3 ++- layers/http.go | 3 ++- layers/icmp.go | 4 +++- layers/icmpv6.go | 3 ++- layers/ipv4.go | 13 ++++++++----- layers/ipv6.go | 21 ++++++++++++++++----- layers/ipv6_test.go | 8 +++++--- layers/layers.go | 24 ++++++++++++++++-------- layers/snmp.go | 3 ++- layers/ssh.go | 3 ++- layers/tcp.go | 33 +++++++-------------------------- layers/tcp_test.go | 2 +- layers/tls.go | 3 ++- layers/udp.go | 6 ++++-- mshark.go | 9 +-------- 18 files changed, 87 insertions(+), 70 deletions(-) diff --git a/layers/arp.go b/layers/arp.go index 0532662..9feab7a 100644 --- a/layers/arp.go +++ b/layers/arp.go @@ -187,7 +187,8 @@ func (ap *ARPPacket) Parse(data []byte) error { return ap.UnmarshalBinary(data) } -func (ap *ARPPacket) NextLayer() (layer string, payload []byte) { return } +func (ap *ARPPacket) NextLayer() Layer { return nil } +func (ap *ARPPacket) Name() string { return "ARP" } func ptypedesc(pt uint16) string { var proto string diff --git a/layers/dns.go b/layers/dns.go index 419d3f6..b9e3eda 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -286,7 +286,9 @@ func (d *DNSMessage) Parse(data []byte) error { return d.UnmarshalBinary(data) } -func (d *DNSMessage) NextLayer() (layer string, payload []byte) { return } +func (d *DNSMessage) NextLayer() Layer { return nil } + +func (d *DNSMessage) Name() string { return "DNS" } func (d *DNSMessage) printRecords() string { var sb strings.Builder diff --git a/layers/ethernet.go b/layers/ethernet.go index cff433a..cbd3742 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -110,11 +110,17 @@ func (ef *EthernetFrame) Parse(data []byte) error { return ef.UnmarshalBinary(data) } -// NextLayer returns the name and payload of the next layer protocol based on the EtherType field of the EthernetFrame. -func (ef *EthernetFrame) NextLayer() (string, []byte) { - return ef.EtherType.Desc, ef.Payload +func (ef *EthernetFrame) NextLayer() Layer { + if next := GetNextLayer(ef.EtherType.Desc); next != nil { + if err := next.Parse(ef.Payload); err == nil { + return next + } + } + return ParseNextLayer(ef.Payload, nil, nil) } +func (ef *EthernetFrame) Name() string { return "ETH" } + func ethertypedesc(et EtherType) string { var etdesc string switch et { diff --git a/layers/ftp.go b/layers/ftp.go index 368c34c..680ec8c 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -41,4 +41,5 @@ func (f *FTPMessage) Parse(data []byte) error { return nil } -func (f *FTPMessage) NextLayer() (layer string, payload []byte) { return } +func (f *FTPMessage) NextLayer() Layer { return nil } +func (f *FTPMessage) Name() string { return "FTP" } diff --git a/layers/http.go b/layers/http.go index 2e67c31..fc426b8 100644 --- a/layers/http.go +++ b/layers/http.go @@ -86,7 +86,8 @@ func (h *HTTPMessage) Parse(data []byte) error { return nil } -func (h *HTTPMessage) NextLayer() (layer string, payload []byte) { return } +func (h *HTTPMessage) NextLayer() Layer { return nil } +func (h *HTTPMessage) Name() string { return "HTTP" } type HTTPRequestWrapper struct { Request HTTPRequest `json:"http_request"` diff --git a/layers/icmp.go b/layers/icmp.go index 40231bc..4fe7b02 100644 --- a/layers/icmp.go +++ b/layers/icmp.go @@ -67,7 +67,9 @@ func (i *ICMPSegment) Parse(data []byte) error { i.TypeDesc, i.CodeDesc = i.typecode() return nil } -func (i *ICMPSegment) NextLayer() (layer string, payload []byte) { return } + +func (i *ICMPSegment) NextLayer() Layer { return nil } +func (i *ICMPSegment) Name() string { return "ICMP" } func (i *ICMPSegment) typecode() (string, string) { // https://en.wikipedia.org/wiki/Internet_Control_Message_Protocol diff --git a/layers/icmpv6.go b/layers/icmpv6.go index c5b4a3a..2f3cdb6 100644 --- a/layers/icmpv6.go +++ b/layers/icmpv6.go @@ -70,7 +70,8 @@ func (i *ICMPv6Segment) Parse(data []byte) error { return nil } -func (i *ICMPv6Segment) NextLayer() (layer string, payload []byte) { return } +func (i *ICMPv6Segment) NextLayer() Layer { return nil } +func (i *ICMPv6Segment) Name() string { return "ICMPv6" } func (i *ICMPv6Segment) typecode() (string, string) { // https://en.wikipedia.org/wiki/ICMPv6 diff --git a/layers/ipv4.go b/layers/ipv4.go index bada8f0..4d9caa5 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -214,14 +214,17 @@ func protodesc(proto IPProto) string { return protodesc } -func (p *IPv4Packet) NextLayer() (string, []byte) { - layer := p.Protocol.Desc - if layer == "Unknown" { - layer = "" +func (p *IPv4Packet) NextLayer() Layer { + if next := GetNextLayer(p.Protocol.Desc); next != nil { + if err := next.Parse(p.Payload); err == nil { + return next + } } - return layer, p.Payload + return ParseNextLayer(p.Payload, nil, nil) } +func (p *IPv4Packet) Name() string { return "IPv4" } + func dscpdesc(dscp uint8) string { // https://en.wikipedia.org/wiki/Differentiated_services var dscpdesc string diff --git a/layers/ipv6.go b/layers/ipv6.go index 9a61cc8..b7800f5 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -44,7 +44,7 @@ type IPv6Packet struct { HopLimit uint8 SrcIP netip.Addr // The unicast IPv6 address of the sending node. DstIP netip.Addr // The IPv6 unicast or multicast address of the destination node(s). - payload []byte + Payload []byte } func (p *IPv6Packet) String() string { @@ -67,7 +67,7 @@ func (p *IPv6Packet) String() string { p.HopLimit, p.SrcIP, p.DstIP, - len(p.payload), + len(p.Payload), ) } @@ -92,11 +92,11 @@ func (p *IPv6Packet) Parse(data []byte) error { p.HopLimit = buf[7] p.SrcIP, _ = netip.AddrFromSlice(buf[8:24]) p.DstIP, _ = netip.AddrFromSlice(buf[24:headerSizeIPv6]) - p.payload = buf[headerSizeIPv6:] + p.Payload = buf[headerSizeIPv6:] return nil } -func (p *IPv6Packet) NextLayer() (string, []byte) { +func (p *IPv6Packet) nextLayer() string { // https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers var layer string switch p.NextHeader { @@ -109,9 +109,20 @@ func (p *IPv6Packet) NextLayer() (string, []byte) { default: layer = "" } - return layer, p.payload + return layer } +func (p *IPv6Packet) NextLayer() Layer { + if next := GetNextLayer(p.nextLayer()); next != nil { + if err := next.Parse(p.Payload); err == nil { + return next + } + } + return ParseNextLayer(p.Payload, nil, nil) +} + +func (p *IPv6Packet) Name() string { return "IPv6" } + func (p *IPv6Packet) nextHeader() string { // https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers var header string diff --git a/layers/ipv6_test.go b/layers/ipv6_test.go index 5cda13d..5f917bb 100644 --- a/layers/ipv6_test.go +++ b/layers/ipv6_test.go @@ -31,11 +31,13 @@ func TestParseIPv6(t *testing.T) { HopLimit: 64, SrcIP: netip.AddrFrom16([16]byte{ 0xFD, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xF2, 0x3F, 0xD1, 0x59, 0x50, 0x48, 0x9C, 0x14}), + 0xF2, 0x3F, 0xD1, 0x59, 0x50, 0x48, 0x9C, 0x14, + }), DstIP: netip.AddrFrom16([16]byte{ 0x26, 0x20, 0x00, 0x2D, 0x40, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2B}), - payload: []byte{}, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2B, + }), + Payload: []byte{}, } ip := &IPv6Packet{} packet, close := testPacket(t, "ipv6") diff --git a/layers/layers.go b/layers/layers.go index 2122ffe..14700fb 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -5,7 +5,6 @@ import ( "bytes" "encoding/binary" "fmt" - "net/netip" "unsafe" ) @@ -45,11 +44,15 @@ var ( type Layer interface { fmt.Stringer Parse(data []byte) error - NextLayer() (layer string, payload []byte) + NextLayer() Layer Summary() string + Name() string } func parseNextLayerFallback(data []byte) Layer { + if len(data) == 0 { + return nil + } for _, layer := range Layers { next := GetNextLayer(layer) if err := next.Parse(data); err == nil { @@ -60,6 +63,9 @@ func parseNextLayerFallback(data []byte) Layer { } func parseNextLayerFromBytes(data []byte) Layer { + if len(data) == 0 { + return nil + } buf := make([]byte, 0, len(data)) buf = append(buf, data...) var next Layer @@ -116,13 +122,13 @@ func parseNextLayerFromBytes(data []byte) Layer { return nil } -func addrMatch(src, dst *netip.AddrPort, ports []uint16) bool { +func addrMatch(src, dst *uint16, ports []uint16) bool { var srcPort, dstPort uint16 if src != nil { - srcPort = src.Port() + srcPort = *src } if dst != nil { - dstPort = src.Port() + dstPort = *dst } for _, port := range ports { if srcPort == port || dstPort == port { @@ -132,7 +138,10 @@ func addrMatch(src, dst *netip.AddrPort, ports []uint16) bool { return false } -func parseNextLayerFromAddress(data []byte, src, dst *netip.AddrPort) Layer { +func parseNextLayerFromAddress(data []byte, src, dst *uint16) Layer { + if len(data) == 0 { + return nil + } var next Layer switch { case addrMatch(src, dst, []uint16{53, 5353, 853, 5355}): @@ -157,7 +166,7 @@ func parseNextLayerFromAddress(data []byte, src, dst *netip.AddrPort) Layer { } } -func ParseNextLayer(data []byte, src, dst *netip.AddrPort) Layer { +func ParseNextLayer(data []byte, src, dst *uint16) Layer { buf := make([]byte, 0, len(data)) buf = append(buf, data...) var next Layer @@ -173,7 +182,6 @@ func ParseNextLayer(data []byte, src, dst *netip.AddrPort) Layer { } func GetNextLayer(layer string) Layer { - // TODO (shadowy-pycoder): add this to NextLayer, choose by ports, parse, use fallback on error switch layer { case "ETH": return &EthernetFrame{} diff --git a/layers/snmp.go b/layers/snmp.go index 815d1b5..21458e0 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -24,4 +24,5 @@ func (s *SNMPMessage) Parse(data []byte) error { return nil } -func (s *SNMPMessage) NextLayer() (layer string, payload []byte) { return } +func (s *SNMPMessage) NextLayer() Layer { return nil } +func (s *SNMPMessage) Name() string { return "SNMP" } diff --git a/layers/ssh.go b/layers/ssh.go index 742d46c..c007f78 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -121,7 +121,8 @@ func (s *SSHMessage) Parse(data []byte) error { return nil } -func (s *SSHMessage) NextLayer() (layer string, payload []byte) { return } +func (s *SSHMessage) NextLayer() Layer { return nil } +func (s *SSHMessage) Name() string { return "SSH" } // https://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml func mtypedesc(mtype uint8) string { diff --git a/layers/tcp.go b/layers/tcp.go index 5d70833..18a6267 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -62,7 +62,7 @@ type TCPSegment struct { // indicating the last urgent data byte. UrgentPointer uint16 Options []byte // The length of this field is determined by the data offset field. - payload []byte + Payload []byte } func (t *TCPSegment) String() string { @@ -94,7 +94,7 @@ func (t *TCPSegment) String() string { t.UrgentPointer, len(t.Options), t.Options, - len(t.payload), + len(t.Payload), ) } @@ -105,7 +105,7 @@ func (t *TCPSegment) Summary() string { t.DstPort, t.Flags, t.WindowSize, - len(t.payload), + len(t.Payload), ) } @@ -128,31 +128,12 @@ func (t *TCPSegment) Parse(data []byte) error { t.Checksum = binary.BigEndian.Uint16(buf[16:18]) t.UrgentPointer = binary.BigEndian.Uint16(buf[18:headerSizeTCP]) t.Options = buf[headerSizeTCP : t.DataOffset<<2] - t.payload = buf[t.DataOffset<<2:] + t.Payload = buf[t.DataOffset<<2:] return nil } -func (t *TCPSegment) NextLayer() (string, []byte) { - return nextAppLayer(t.SrcPort, t.DstPort), t.payload +func (t *TCPSegment) NextLayer() Layer { + return ParseNextLayer(t.Payload, &t.SrcPort, &t.DstPort) } -func nextAppLayer(src, dst uint16) string { - var layer string - switch { - case src == 20 || dst == 20 || src == 21 || dst == 21: - layer = "FTP" - case src == 22 || dst == 22: - layer = "SSH" - case src == 53 || dst == 53: - layer = "DNS" - case src == 80 || dst == 80: - layer = "HTTP" - case src == 161 || dst == 161 || src == 162 || dst == 162: - layer = "SNMP" - case src == 443 || dst == 443: - layer = "TLS" - default: - layer = "" - } - return layer -} +func (t *TCPSegment) Name() string { return "TCP" } diff --git a/layers/tcp_test.go b/layers/tcp_test.go index b5627ce..3c924e1 100644 --- a/layers/tcp_test.go +++ b/layers/tcp_test.go @@ -46,7 +46,7 @@ func TestParseTCP(t *testing.T) { 0x02, 0x04, 0x05, 0xb4, 0x04, 0x02, 0x08, 0x0a, 0xac, 0xf8, 0x48, 0x3e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07, }, - payload: []byte{}, + Payload: []byte{}, } tcp := &TCPSegment{} packet, close := testPacket(t, "tcp") diff --git a/layers/tls.go b/layers/tls.go index ce46dc2..cc21780 100644 --- a/layers/tls.go +++ b/layers/tls.go @@ -581,7 +581,8 @@ func (t *TLSMessage) Parse(data []byte) error { return nil } -func (t *TLSMessage) NextLayer() (layer string, payload []byte) { return } +func (t *TLSMessage) NextLayer() Layer { return nil } +func (t *TLSMessage) Name() string { return "TLS" } func ctdesc(ct uint8) string { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-5 diff --git a/layers/udp.go b/layers/udp.go index ca6565c..f859da9 100644 --- a/layers/udp.go +++ b/layers/udp.go @@ -86,10 +86,12 @@ func (u *UDPSegment) Parse(data []byte) error { return u.UnmarshalBinary(data) } -func (u *UDPSegment) NextLayer() (string, []byte) { - return nextAppLayer(u.SrcPort, u.DstPort), u.Payload +func (u *UDPSegment) NextLayer() Layer { + return ParseNextLayer(u.Payload, &u.SrcPort, &u.DstPort) } +func (u *UDPSegment) Name() string { return "UDP" } + func CalculateUDPChecksum(data []byte) (uint16, error) { var sum uint16 udpLength := len(data) diff --git a/mshark.go b/mshark.go index 1d1f413..b0ede90 100644 --- a/mshark.go +++ b/mshark.go @@ -94,17 +94,10 @@ func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { var layerNum int mw.printPacket(next, layerNum) for { - name, data := next.NextLayer() - if name == "" || data == nil || len(data) == 0 { - return nil - } - next = layers.GetNextLayer(name) + next = next.NextLayer() if next == nil { return nil } - if err := next.Parse(data); err != nil { - return err - } layerNum++ mw.printPacket(next, layerNum) } From e9a6d6689bd47ab42852c53e445a73f1ce3236d2 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 21:11:58 +0300 Subject: [PATCH 08/17] added more checks for parsing --- layers/ipv6.go | 11 +++++++++-- layers/layers.go | 12 +++++++++--- layers/snmp.go | 3 +++ layers/ssh.go | 12 +++++++++++- layers/tcp.go | 8 ++++++-- 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/layers/ipv6.go b/layers/ipv6.go index b7800f5..c1ddc0e 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -90,8 +90,15 @@ func (p *IPv6Packet) Parse(data []byte) error { p.NextHeader = buf[6] p.NextHeaderDesc = p.nextHeader() p.HopLimit = buf[7] - p.SrcIP, _ = netip.AddrFromSlice(buf[8:24]) - p.DstIP, _ = netip.AddrFromSlice(buf[24:headerSizeIPv6]) + var ok bool + p.SrcIP, ok = netip.AddrFromSlice(buf[8:24]) + if !ok { + return fmt.Errorf("malformed IPv6 address") + } + p.DstIP, ok = netip.AddrFromSlice(buf[24:headerSizeIPv6]) + if !ok { + return fmt.Errorf("malformed IPv6 address") + } p.Payload = buf[headerSizeIPv6:] return nil } diff --git a/layers/layers.go b/layers/layers.go index 14700fb..02683ff 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -88,6 +88,12 @@ func parseNextLayerFromBytes(data []byte) Layer { return next } } + if firstByte == 0x30 { + next = GetNextLayer("SNMP") + if err := next.Parse(buf); err == nil { + return next + } + } if len(buf) > 3 { b1 := binary.BigEndian.Uint16(buf[0:2]) b2 := binary.BigEndian.Uint16(buf[2:4]) @@ -138,7 +144,7 @@ func addrMatch(src, dst *uint16, ports []uint16) bool { return false } -func parseNextLayerFromAddress(data []byte, src, dst *uint16) Layer { +func parseNextLayerFromPorts(data []byte, src, dst *uint16) Layer { if len(data) == 0 { return nil } @@ -154,7 +160,7 @@ func parseNextLayerFromAddress(data []byte, src, dst *uint16) Layer { next = GetNextLayer("FTP") case addrMatch(src, dst, []uint16{22, 2222, 2200, 222, 2022}): next = GetNextLayer("SSH") - case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444}): + case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444, 5228}): next = GetNextLayer("TLS") default: return nil @@ -171,7 +177,7 @@ func ParseNextLayer(data []byte, src, dst *uint16) Layer { buf = append(buf, data...) var next Layer if src != nil || dst != nil { - if next = parseNextLayerFromAddress(buf, src, dst); next != nil { + if next = parseNextLayerFromPorts(buf, src, dst); next != nil { return next } } diff --git a/layers/snmp.go b/layers/snmp.go index 21458e0..0f6ad01 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -18,6 +18,9 @@ func (s *SNMPMessage) Summary() string { } func (s *SNMPMessage) Parse(data []byte) error { + if data[0] != 0x30 { + return fmt.Errorf("not ASN.1 SEQUENCE") + } buf := make([]byte, 0, len(data)) buf = append(buf, data...) s.Payload = buf diff --git a/layers/ssh.go b/layers/ssh.go index c007f78..2a98b7b 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -90,11 +90,18 @@ func (s *SSHMessage) Parse(data []byte) error { buf := make([]byte, 0, len(data)) buf = append(buf, data...) if bytes.HasSuffix(buf, crlf) { - s.Protocol = bytesToStr(bytes.TrimSuffix(buf, crlf)) + p := bytes.TrimSuffix(buf, crlf) + if !bytes.Contains(p, protoSSH) { + return fmt.Errorf("message should contain SSH-") + } + s.Protocol = bytesToStr(p) return nil } s.Messages = make([]*Message, 0, 3) for len(buf) > 0 { + if len(buf) < 4 { + return ErrSliceBounds + } m := &Message{} s.Messages = append(s.Messages, m) plen := binary.BigEndian.Uint32(buf[0:4]) @@ -102,6 +109,9 @@ func (s *SSHMessage) Parse(data []byte) error { m.Payload = buf break } + if len(buf) < 5 { + return ErrSliceBounds + } m.MesssageType = buf[5] if m.MesssageTypeDesc = mtypedesc(m.MesssageType); m.MesssageTypeDesc == "Unknown" { m.Payload = buf diff --git a/layers/tcp.go b/layers/tcp.go index 18a6267..790c22b 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -127,8 +127,12 @@ func (t *TCPSegment) Parse(data []byte) error { t.WindowSize = binary.BigEndian.Uint16(buf[14:16]) t.Checksum = binary.BigEndian.Uint16(buf[16:18]) t.UrgentPointer = binary.BigEndian.Uint16(buf[18:headerSizeTCP]) - t.Options = buf[headerSizeTCP : t.DataOffset<<2] - t.Payload = buf[t.DataOffset<<2:] + offset := t.DataOffset << 2 + if len(buf) < int(offset) { + return ErrSliceBounds + } + t.Options = buf[headerSizeTCP:offset] + t.Payload = buf[offset:] return nil } From adad59e3dc020854c00ed06242c34b9870e38ec1 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 8 Aug 2025 21:15:48 +0300 Subject: [PATCH 09/17] added additional check for IPv4 --- layers/ipv4.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/layers/ipv4.go b/layers/ipv4.go index 4d9caa5..4064735 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -185,6 +185,9 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { } if p.IHL > 5 { offset := headerSizeIPv4 + ((p.IHL - 5) << 2) + if int(offset) > len(buf) { + return ErrSliceBounds + } p.Options = buf[headerSizeIPv4:offset] p.Payload = buf[offset:] } else { From c4284e7fa31945cec88ddcd4e048bd4c3e568aea Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Sat, 9 Aug 2025 07:28:16 +0300 Subject: [PATCH 10/17] updated level name logic --- layers/arp.go | 2 +- layers/dns.go | 2 +- layers/ethernet.go | 4 +- layers/ftp.go | 2 +- layers/http.go | 2 +- layers/icmp.go | 2 +- layers/icmpv6.go | 2 +- layers/ipv4.go | 4 +- layers/ipv6.go | 4 +- layers/layers.go | 109 ++++++++++++++++++++++++++------------------- layers/snmp.go | 2 +- layers/ssh.go | 2 +- layers/tcp.go | 2 +- layers/tls.go | 2 +- layers/udp.go | 2 +- mshark.go | 2 +- 16 files changed, 82 insertions(+), 63 deletions(-) diff --git a/layers/arp.go b/layers/arp.go index 9feab7a..036f0c0 100644 --- a/layers/arp.go +++ b/layers/arp.go @@ -188,7 +188,7 @@ func (ap *ARPPacket) Parse(data []byte) error { } func (ap *ARPPacket) NextLayer() Layer { return nil } -func (ap *ARPPacket) Name() string { return "ARP" } +func (ap *ARPPacket) Name() LayerName { return LayerARP } func ptypedesc(pt uint16) string { var proto string diff --git a/layers/dns.go b/layers/dns.go index b9e3eda..b3d7085 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -288,7 +288,7 @@ func (d *DNSMessage) Parse(data []byte) error { func (d *DNSMessage) NextLayer() Layer { return nil } -func (d *DNSMessage) Name() string { return "DNS" } +func (d *DNSMessage) Name() LayerName { return LayerDNS } func (d *DNSMessage) printRecords() string { var sb strings.Builder diff --git a/layers/ethernet.go b/layers/ethernet.go index cbd3742..7d119d4 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -111,7 +111,7 @@ func (ef *EthernetFrame) Parse(data []byte) error { } func (ef *EthernetFrame) NextLayer() Layer { - if next := GetNextLayer(ef.EtherType.Desc); next != nil { + if next := GetNextLayer(LayerName(ef.EtherType.Desc)); next != nil { if err := next.Parse(ef.Payload); err == nil { return next } @@ -119,7 +119,7 @@ func (ef *EthernetFrame) NextLayer() Layer { return ParseNextLayer(ef.Payload, nil, nil) } -func (ef *EthernetFrame) Name() string { return "ETH" } +func (ef *EthernetFrame) Name() LayerName { return LayerETH } func ethertypedesc(et EtherType) string { var etdesc string diff --git a/layers/ftp.go b/layers/ftp.go index 680ec8c..9ecf6d4 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -42,4 +42,4 @@ func (f *FTPMessage) Parse(data []byte) error { } func (f *FTPMessage) NextLayer() Layer { return nil } -func (f *FTPMessage) Name() string { return "FTP" } +func (f *FTPMessage) Name() LayerName { return LayerFTP } diff --git a/layers/http.go b/layers/http.go index fc426b8..535394c 100644 --- a/layers/http.go +++ b/layers/http.go @@ -87,7 +87,7 @@ func (h *HTTPMessage) Parse(data []byte) error { } func (h *HTTPMessage) NextLayer() Layer { return nil } -func (h *HTTPMessage) Name() string { return "HTTP" } +func (h *HTTPMessage) Name() LayerName { return LayerHTTP } type HTTPRequestWrapper struct { Request HTTPRequest `json:"http_request"` diff --git a/layers/icmp.go b/layers/icmp.go index 4fe7b02..495b8ea 100644 --- a/layers/icmp.go +++ b/layers/icmp.go @@ -69,7 +69,7 @@ func (i *ICMPSegment) Parse(data []byte) error { } func (i *ICMPSegment) NextLayer() Layer { return nil } -func (i *ICMPSegment) Name() string { return "ICMP" } +func (i *ICMPSegment) Name() LayerName { return LayerICMP } func (i *ICMPSegment) typecode() (string, string) { // https://en.wikipedia.org/wiki/Internet_Control_Message_Protocol diff --git a/layers/icmpv6.go b/layers/icmpv6.go index 2f3cdb6..51ab311 100644 --- a/layers/icmpv6.go +++ b/layers/icmpv6.go @@ -71,7 +71,7 @@ func (i *ICMPv6Segment) Parse(data []byte) error { } func (i *ICMPv6Segment) NextLayer() Layer { return nil } -func (i *ICMPv6Segment) Name() string { return "ICMPv6" } +func (i *ICMPv6Segment) Name() LayerName { return LayerICMPv6 } func (i *ICMPv6Segment) typecode() (string, string) { // https://en.wikipedia.org/wiki/ICMPv6 diff --git a/layers/ipv4.go b/layers/ipv4.go index 4064735..f15f953 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -218,7 +218,7 @@ func protodesc(proto IPProto) string { } func (p *IPv4Packet) NextLayer() Layer { - if next := GetNextLayer(p.Protocol.Desc); next != nil { + if next := GetNextLayer(LayerName(p.Protocol.Desc)); next != nil { if err := next.Parse(p.Payload); err == nil { return next } @@ -226,7 +226,7 @@ func (p *IPv4Packet) NextLayer() Layer { return ParseNextLayer(p.Payload, nil, nil) } -func (p *IPv4Packet) Name() string { return "IPv4" } +func (p *IPv4Packet) Name() LayerName { return LayerIPv4 } func dscpdesc(dscp uint8) string { // https://en.wikipedia.org/wiki/Differentiated_services diff --git a/layers/ipv6.go b/layers/ipv6.go index c1ddc0e..66366f2 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -120,7 +120,7 @@ func (p *IPv6Packet) nextLayer() string { } func (p *IPv6Packet) NextLayer() Layer { - if next := GetNextLayer(p.nextLayer()); next != nil { + if next := GetNextLayer(LayerName(p.nextLayer())); next != nil { if err := next.Parse(p.Payload); err == nil { return next } @@ -128,7 +128,7 @@ func (p *IPv6Packet) NextLayer() Layer { return ParseNextLayer(p.Payload, nil, nil) } -func (p *IPv6Packet) Name() string { return "IPv6" } +func (p *IPv6Packet) Name() LayerName { return LayerIPv6 } func (p *IPv6Packet) nextHeader() string { // https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers diff --git a/layers/layers.go b/layers/layers.go index 02683ff..8e341b0 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -10,21 +10,40 @@ import ( const maxLenSummary = 110 -var Layers = []string{ - "ETH", - "IPv4", - "IPv6", - "ARP", - "TCP", - "UDP", - "ICMP", - "ICMPv6", - "DNS", - "FTP", - "HTTP", - "SNMP", - "SSH", - "TLS", +type LayerName string + +const ( + LayerETH LayerName = "ETH" + LayerIPv4 LayerName = "IPv4" + LayerIPv6 LayerName = "IPv6" + LayerARP LayerName = "ARP" + LayerTCP LayerName = "TCP" + LayerUDP LayerName = "UDP" + LayerICMP LayerName = "ICMP" + LayerICMPv6 LayerName = "ICMPv6" + LayerDNS LayerName = "DNS" + LayerFTP LayerName = "FTP" + LayerHTTP LayerName = "HTTP" + LayerSNMP LayerName = "SNMP" + LayerSSH LayerName = "SSH" + LayerTLS LayerName = "TLS" +) + +var Layers = []LayerName{ + LayerETH, + LayerIPv4, + LayerIPv6, + LayerTLS, + LayerHTTP, + LayerDNS, + LayerARP, + LayerTCP, + LayerUDP, + LayerICMP, + LayerICMPv6, + LayerSNMP, + LayerSSH, + LayerFTP, } var ( @@ -46,7 +65,7 @@ type Layer interface { Parse(data []byte) error NextLayer() Layer Summary() string - Name() string + Name() LayerName } func parseNextLayerFallback(data []byte) Layer { @@ -71,25 +90,25 @@ func parseNextLayerFromBytes(data []byte) Layer { var next Layer firstByte := buf[0] if firstByte >= 0x45 && firstByte <= 0x4F { - next = GetNextLayer("IPv4") + next = GetNextLayer(LayerIPv4) if err := next.Parse(buf); err == nil { return next } } if firstByte>>4 == 6 { - next = GetNextLayer("IPv6") + next = GetNextLayer(LayerIPv6) if err := next.Parse(buf); err == nil { return next } } if firstByte == HandshakeTLSVal { - next = GetNextLayer("TLS") + next = GetNextLayer(LayerTLS) if err := next.Parse(buf); err == nil { return next } } if firstByte == 0x30 { - next = GetNextLayer("SNMP") + next = GetNextLayer(LayerSNMP) if err := next.Parse(buf); err == nil { return next } @@ -98,7 +117,7 @@ func parseNextLayerFromBytes(data []byte) Layer { b1 := binary.BigEndian.Uint16(buf[0:2]) b2 := binary.BigEndian.Uint16(buf[2:4]) if b1 == 1 && (b2 == 0x0800 || b2 == 0x86dd) { - next = GetNextLayer("ARP") + next = GetNextLayer(LayerARP) if err := next.Parse(buf); err == nil { return next } @@ -107,20 +126,20 @@ func parseNextLayerFromBytes(data []byte) Layer { if len(buf) > 15 { b1 := binary.BigEndian.Uint16(buf[12:14]) if b1 == 0x0806 || b1 == 0x0800 || b1 == 0x86dd { - next = GetNextLayer("ETH") + next = GetNextLayer(LayerETH) if err := next.Parse(buf); err == nil { return next } } } if bytes.Contains(buf, protohttp10) || bytes.Contains(buf, protohttp11) { - next = GetNextLayer("HTTP") + next = GetNextLayer(LayerHTTP) if err := next.Parse(buf); err == nil { return next } } if bytes.Contains(buf, protoSSH) { - next = GetNextLayer("SSH") + next = GetNextLayer(LayerSSH) if err := next.Parse(buf); err == nil { return next } @@ -151,17 +170,17 @@ func parseNextLayerFromPorts(data []byte, src, dst *uint16) Layer { var next Layer switch { case addrMatch(src, dst, []uint16{53, 5353, 853, 5355}): - next = GetNextLayer("DNS") + next = GetNextLayer(LayerDNS) case addrMatch(src, dst, []uint16{80, 8080, 8000, 8888, 81, 591, 5911}): - next = GetNextLayer("HTTP") + next = GetNextLayer(LayerHTTP) case addrMatch(src, dst, []uint16{161, 162, 10161, 10162, 1161, 2161}): - next = GetNextLayer("SNMP") + next = GetNextLayer(LayerSNMP) case addrMatch(src, dst, []uint16{21, 20, 2121, 8021}): - next = GetNextLayer("FTP") + next = GetNextLayer(LayerFTP) case addrMatch(src, dst, []uint16{22, 2222, 2200, 222, 2022}): - next = GetNextLayer("SSH") + next = GetNextLayer(LayerSSH) case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444, 5228}): - next = GetNextLayer("TLS") + next = GetNextLayer(LayerTLS) default: return nil } @@ -187,35 +206,35 @@ func ParseNextLayer(data []byte, src, dst *uint16) Layer { return parseNextLayerFallback(buf) } -func GetNextLayer(layer string) Layer { +func GetNextLayer(layer LayerName) Layer { switch layer { - case "ETH": + case LayerETH: return &EthernetFrame{} - case "IPv4": + case LayerIPv4: return &IPv4Packet{} - case "IPv6": + case LayerIPv6: return &IPv6Packet{} - case "ARP": + case LayerARP: return &ARPPacket{} - case "TCP": + case LayerTCP: return &TCPSegment{} - case "UDP": + case LayerUDP: return &UDPSegment{} - case "ICMP": + case LayerICMP: return &ICMPSegment{} - case "ICMPv6": + case LayerICMPv6: return &ICMPv6Segment{} - case "DNS": + case LayerDNS: return &DNSMessage{} - case "FTP": + case LayerFTP: return &FTPMessage{} - case "HTTP": + case LayerHTTP: return &HTTPMessage{} - case "SNMP": + case LayerSNMP: return &SNMPMessage{} - case "SSH": + case LayerSSH: return &SSHMessage{} - case "TLS": + case LayerTLS: return &TLSMessage{} default: return nil diff --git a/layers/snmp.go b/layers/snmp.go index 0f6ad01..5bd2d23 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -28,4 +28,4 @@ func (s *SNMPMessage) Parse(data []byte) error { } func (s *SNMPMessage) NextLayer() Layer { return nil } -func (s *SNMPMessage) Name() string { return "SNMP" } +func (s *SNMPMessage) Name() LayerName { return LayerSNMP } diff --git a/layers/ssh.go b/layers/ssh.go index 2a98b7b..ea7e98d 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -132,7 +132,7 @@ func (s *SSHMessage) Parse(data []byte) error { } func (s *SSHMessage) NextLayer() Layer { return nil } -func (s *SSHMessage) Name() string { return "SSH" } +func (s *SSHMessage) Name() LayerName { return LayerSSH } // https://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml func mtypedesc(mtype uint8) string { diff --git a/layers/tcp.go b/layers/tcp.go index 790c22b..374cb23 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -140,4 +140,4 @@ func (t *TCPSegment) NextLayer() Layer { return ParseNextLayer(t.Payload, &t.SrcPort, &t.DstPort) } -func (t *TCPSegment) Name() string { return "TCP" } +func (t *TCPSegment) Name() LayerName { return LayerTCP } diff --git a/layers/tls.go b/layers/tls.go index cc21780..8e74bcd 100644 --- a/layers/tls.go +++ b/layers/tls.go @@ -582,7 +582,7 @@ func (t *TLSMessage) Parse(data []byte) error { } func (t *TLSMessage) NextLayer() Layer { return nil } -func (t *TLSMessage) Name() string { return "TLS" } +func (t *TLSMessage) Name() LayerName { return LayerTLS } func ctdesc(ct uint8) string { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-5 diff --git a/layers/udp.go b/layers/udp.go index f859da9..d2ad1a5 100644 --- a/layers/udp.go +++ b/layers/udp.go @@ -90,7 +90,7 @@ func (u *UDPSegment) NextLayer() Layer { return ParseNextLayer(u.Payload, &u.SrcPort, &u.DstPort) } -func (u *UDPSegment) Name() string { return "UDP" } +func (u *UDPSegment) Name() LayerName { return LayerUDP } func CalculateUDPChecksum(data []byte) (uint16, error) { var sum uint16 diff --git a/mshark.go b/mshark.go index b0ede90..236f529 100644 --- a/mshark.go +++ b/mshark.go @@ -84,7 +84,7 @@ func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { mw.packets++ fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05.000000-0700")) fmt.Fprintln(mw.w, packetDelimeter) - next := layers.GetNextLayer("ETH") + next := layers.GetNextLayer(layers.LayerETH) if next == nil { return nil } From 514b6c3c2f0127d3ef2f93460ea708be75c5444e Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Sat, 9 Aug 2025 14:37:08 +0300 Subject: [PATCH 11/17] added more checks for FTP and SNMP --- layers/ethernet.go | 2 +- layers/ftp.go | 9 +++++++++ layers/ipv4.go | 2 +- layers/ipv6.go | 2 +- layers/layers.go | 48 ++++++++++++++++++++++++++++++---------------- layers/snmp.go | 10 +++++++--- mshark.go | 2 +- 7 files changed, 51 insertions(+), 24 deletions(-) diff --git a/layers/ethernet.go b/layers/ethernet.go index 7d119d4..5e31870 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -111,7 +111,7 @@ func (ef *EthernetFrame) Parse(data []byte) error { } func (ef *EthernetFrame) NextLayer() Layer { - if next := GetNextLayer(LayerName(ef.EtherType.Desc)); next != nil { + if next := GetLayer(LayerName(ef.EtherType.Desc)); next != nil { if err := next.Parse(ef.Payload); err == nil { return next } diff --git a/layers/ftp.go b/layers/ftp.go index 9ecf6d4..63d6592 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -23,6 +23,9 @@ func (f *FTPMessage) Summary() string { func (f *FTPMessage) Parse(data []byte) error { buf := make([]byte, 0, len(data)) buf = append(buf, data...) + if !checkFTP(buf) { + return fmt.Errorf("malformed ftp message") + } f.summary = nil f.data = nil sp := bytes.Split(buf, crlf) @@ -43,3 +46,9 @@ func (f *FTPMessage) Parse(data []byte) error { func (f *FTPMessage) NextLayer() Layer { return nil } func (f *FTPMessage) Name() LayerName { return LayerFTP } + +func checkFTP(data []byte) bool { + return (len(data) >= 4 && isDigit(data[0]) && isDigit(data[1]) && + isDigit(data[2]) && (data[3] == ' ' || data[3] == '-')) || + (len(data) >= 3 && isUpper(data[0]) && isUpper(data[1]) && isUpper(data[2])) +} diff --git a/layers/ipv4.go b/layers/ipv4.go index f15f953..94c3b29 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -218,7 +218,7 @@ func protodesc(proto IPProto) string { } func (p *IPv4Packet) NextLayer() Layer { - if next := GetNextLayer(LayerName(p.Protocol.Desc)); next != nil { + if next := GetLayer(LayerName(p.Protocol.Desc)); next != nil { if err := next.Parse(p.Payload); err == nil { return next } diff --git a/layers/ipv6.go b/layers/ipv6.go index 66366f2..cd54ee4 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -120,7 +120,7 @@ func (p *IPv6Packet) nextLayer() string { } func (p *IPv6Packet) NextLayer() Layer { - if next := GetNextLayer(LayerName(p.nextLayer())); next != nil { + if next := GetLayer(LayerName(p.nextLayer())); next != nil { if err := next.Parse(p.Payload); err == nil { return next } diff --git a/layers/layers.go b/layers/layers.go index 8e341b0..d4f6dad 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -73,7 +73,7 @@ func parseNextLayerFallback(data []byte) Layer { return nil } for _, layer := range Layers { - next := GetNextLayer(layer) + next := GetLayer(layer) if err := next.Parse(data); err == nil { return next } @@ -90,25 +90,25 @@ func parseNextLayerFromBytes(data []byte) Layer { var next Layer firstByte := buf[0] if firstByte >= 0x45 && firstByte <= 0x4F { - next = GetNextLayer(LayerIPv4) + next = GetLayer(LayerIPv4) if err := next.Parse(buf); err == nil { return next } } if firstByte>>4 == 6 { - next = GetNextLayer(LayerIPv6) + next = GetLayer(LayerIPv6) if err := next.Parse(buf); err == nil { return next } } if firstByte == HandshakeTLSVal { - next = GetNextLayer(LayerTLS) + next = GetLayer(LayerTLS) if err := next.Parse(buf); err == nil { return next } } - if firstByte == 0x30 { - next = GetNextLayer(LayerSNMP) + if checkFTP(buf) { + next = GetLayer(LayerFTP) if err := next.Parse(buf); err == nil { return next } @@ -117,29 +117,35 @@ func parseNextLayerFromBytes(data []byte) Layer { b1 := binary.BigEndian.Uint16(buf[0:2]) b2 := binary.BigEndian.Uint16(buf[2:4]) if b1 == 1 && (b2 == 0x0800 || b2 == 0x86dd) { - next = GetNextLayer(LayerARP) + next = GetLayer(LayerARP) if err := next.Parse(buf); err == nil { return next } } } + if checkSNMP(buf) { + next = GetLayer(LayerSNMP) + if err := next.Parse(buf); err == nil { + return next + } + } if len(buf) > 15 { b1 := binary.BigEndian.Uint16(buf[12:14]) if b1 == 0x0806 || b1 == 0x0800 || b1 == 0x86dd { - next = GetNextLayer(LayerETH) + next = GetLayer(LayerETH) if err := next.Parse(buf); err == nil { return next } } } if bytes.Contains(buf, protohttp10) || bytes.Contains(buf, protohttp11) { - next = GetNextLayer(LayerHTTP) + next = GetLayer(LayerHTTP) if err := next.Parse(buf); err == nil { return next } } if bytes.Contains(buf, protoSSH) { - next = GetNextLayer(LayerSSH) + next = GetLayer(LayerSSH) if err := next.Parse(buf); err == nil { return next } @@ -170,17 +176,17 @@ func parseNextLayerFromPorts(data []byte, src, dst *uint16) Layer { var next Layer switch { case addrMatch(src, dst, []uint16{53, 5353, 853, 5355}): - next = GetNextLayer(LayerDNS) + next = GetLayer(LayerDNS) case addrMatch(src, dst, []uint16{80, 8080, 8000, 8888, 81, 591, 5911}): - next = GetNextLayer(LayerHTTP) + next = GetLayer(LayerHTTP) case addrMatch(src, dst, []uint16{161, 162, 10161, 10162, 1161, 2161}): - next = GetNextLayer(LayerSNMP) + next = GetLayer(LayerSNMP) case addrMatch(src, dst, []uint16{21, 20, 2121, 8021}): - next = GetNextLayer(LayerFTP) + next = GetLayer(LayerFTP) case addrMatch(src, dst, []uint16{22, 2222, 2200, 222, 2022}): - next = GetNextLayer(LayerSSH) + next = GetLayer(LayerSSH) case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444, 5228}): - next = GetNextLayer(LayerTLS) + next = GetLayer(LayerTLS) default: return nil } @@ -206,7 +212,7 @@ func ParseNextLayer(data []byte, src, dst *uint16) Layer { return parseNextLayerFallback(buf) } -func GetNextLayer(layer LayerName) Layer { +func GetLayer(layer LayerName) Layer { switch layer { case LayerETH: return &EthernetFrame{} @@ -262,3 +268,11 @@ func add16WithCarryWrapAround(x, y uint16) uint16 { sum32 = (sum32 & 0xFFFF) + (sum32 >> 16) return uint16(sum32) } + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} + +func isUpper(b byte) bool { + return b >= 'A' && b <= 'Z' +} diff --git a/layers/snmp.go b/layers/snmp.go index 5bd2d23..b232d7e 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -18,14 +18,18 @@ func (s *SNMPMessage) Summary() string { } func (s *SNMPMessage) Parse(data []byte) error { - if data[0] != 0x30 { - return fmt.Errorf("not ASN.1 SEQUENCE") - } buf := make([]byte, 0, len(data)) buf = append(buf, data...) + if !checkSNMP(buf) { + return fmt.Errorf("not ASN.1 SEQUENCE") + } s.Payload = buf return nil } func (s *SNMPMessage) NextLayer() Layer { return nil } func (s *SNMPMessage) Name() LayerName { return LayerSNMP } + +func checkSNMP(data []byte) bool { + return len(data) > 6 && data[0] == 0x30 && (data[2] == 0x02 || data[2] == 0x04) +} diff --git a/mshark.go b/mshark.go index 236f529..5b5af40 100644 --- a/mshark.go +++ b/mshark.go @@ -84,7 +84,7 @@ func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { mw.packets++ fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05.000000-0700")) fmt.Fprintln(mw.w, packetDelimeter) - next := layers.GetNextLayer(layers.LayerETH) + next := layers.GetLayer(layers.LayerETH) if next == nil { return nil } From db96ad578989b70987c9a4e71949abcacae5c8e5 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Mon, 11 Aug 2025 08:05:48 +0300 Subject: [PATCH 12/17] added more bounds checks, removed fallback parsing, refactored creating of packet connection --- layers/arp.go | 9 +++ layers/dns.go | 157 +++++++++++++++++++++++++-------------------- layers/ethernet.go | 5 +- layers/http.go | 10 +-- layers/ipv4.go | 19 +++++- layers/ipv6.go | 11 +++- layers/layers.go | 5 +- layers/tcp.go | 2 +- layers/tls.go | 14 +++- layers/udp.go | 18 +++--- mshark.go | 42 ++---------- network/network.go | 57 ++++++++++++++++ 12 files changed, 220 insertions(+), 129 deletions(-) diff --git a/layers/arp.go b/layers/arp.go index 036f0c0..f69a07e 100644 --- a/layers/arp.go +++ b/layers/arp.go @@ -167,12 +167,21 @@ func (ap *ARPPacket) UnmarshalBinary(data []byte) error { hoffset := 8 + ap.Hlen ap.SenderMAC = net.HardwareAddr(buf[8:hoffset]) poffset := hoffset + ap.Plen + if int(poffset) > len(buf) { + return ErrSliceBounds + } var ok bool ap.SenderIP, ok = netip.AddrFromSlice(buf[hoffset:poffset]) if !ok { return fmt.Errorf("failed parsing sender IP address") } + if int(poffset+ap.Hlen) > len(buf) { + return ErrSliceBounds + } ap.TargetMAC = net.HardwareAddr(buf[poffset : poffset+ap.Hlen]) + if int(poffset+ap.Hlen+ap.Plen) > len(buf) { + return ErrSliceBounds + } ap.TargetIP, ok = netip.AddrFromSlice(buf[poffset+ap.Hlen : poffset+ap.Hlen+ap.Plen]) if !ok { return fmt.Errorf("failed parsing target IP address") diff --git a/layers/dns.go b/layers/dns.go index b3d7085..2526629 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -3,30 +3,29 @@ package layers import ( "encoding/binary" "encoding/hex" + "encoding/json" "fmt" "net/netip" "strings" ) -// TODO (shadowy-pycoder): add MarshalJSON - const headerSizeDNS = 12 type DNSFlags struct { - Raw uint16 - QR uint8 // Indicates if the message is a query (0) or a reply (1). - QRDesc string // Query (0) or Reply (1) - OPCode uint8 // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5 - OPCodeDesc string - AA uint8 // Authoritative Answer, in a response, indicates if the DNS server is authoritative for the queried hostname. - TC uint8 // TrunCation, indicates that this message was truncated due to excessive length. - RD uint8 // Recursion Desired, indicates if the client means a recursive query. - RA uint8 // Recursion Available, in a response, indicates if the replying DNS server supports recursion. - Z uint8 // Zero, reserved for future use. - AU uint8 // Indicates if answer/authority portion was authenticated by the server. - NA uint8 // Indicates if non-authenticated data is accepatable. - RCode uint8 // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 - RCodeDesc string + Raw uint16 `json:"raw"` + QR uint8 `json:"qr"` // Indicates if the message is a query (0) or a reply (1). + QRDesc string `json:"qrdesc"` // Query (0) or Reply (1) + OPCode uint8 `json:"opcode"` // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5 + OPCodeDesc string `json:"opcodedesc"` + AA uint8 `json:"aa"` // Authoritative Answer, in a response, indicates if the DNS server is authoritative for the queried hostname. + TC uint8 `json:"tc"` // TrunCation, indicates that this message was truncated due to excessive length. + RD uint8 `json:"rd"` // Recursion Desired, indicates if the client means a recursive query. + RA uint8 `json:"ra"` // Recursion Available, in a response, indicates if the replying DNS server supports recursion. + Z uint8 `json:"z"` // Zero, reserved for future use. + AU uint8 `json:"au"` // Indicates if answer/authority portion was authenticated by the server. + NA uint8 `json:"na"` // Indicates if non-authenticated data is accepatable. + RCode uint8 `json:"rcode"` // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 + RCodeDesc string `json:"rcodedesc"` } func (df *DNSFlags) String() string { @@ -173,16 +172,16 @@ func rcdesc(rcode uint8) string { } type DNSMessage struct { - TransactionID uint16 // Used for matching response to queries. - Flags *DNSFlags // Flags specify the requested operation and a response code. - QDCount uint16 // Count of entries in the queries section. - ANCount uint16 // Count of entries in the answers section. - NSCount uint16 // Count of entries in the authority section. - ARCount uint16 // Count of entries in the additional section. - Questions []*QueryEntry - AnswerRRs []*ResourceRecord - AuthorityRRs []*ResourceRecord - AdditionalRRs []*ResourceRecord + TransactionID uint16 `json:"transaction-id"` // Used for matching response to queries. + Flags *DNSFlags `json:"flags,omitempty"` // Flags specify the requested operation and a response code. + QDCount uint16 `json:"questions-count"` // Count of entries in the queries section. + ANCount uint16 `json:"answer-rrs-count"` // Count of entries in the answers section. + NSCount uint16 `json:"authority-rrs-count"` // Count of entries in the authority section. + ARCount uint16 `json:"additional-rrs-count"` // Count of entries in the additional section. + Questions []*QueryEntry `json:"questions,omitempty"` + AnswerRRs []*ResourceRecord `json:"answers,omitempty"` + AuthorityRRs []*ResourceRecord `json:"authoritative-nameservers,omitempty"` + AdditionalRRs []*ResourceRecord `json:"additional-records,omitempty"` } func (d *DNSMessage) String() string { @@ -313,15 +312,31 @@ func (d *DNSMessage) printRecords() string { if d.ARCount > 0 { sb.WriteString("- Additional records:\n") for _, rec := range d.AdditionalRRs { - sb.WriteString(rec.String()) + sb.WriteString(strings.TrimSuffix(rec.String(), "\n")) } } - return strings.TrimSuffix(sb.String(), "\n") + return sb.String() +} + +type dnsMessageAlias DNSMessage + +type dnsQueryWrapper struct { + Query *dnsMessageAlias `json:"dns_query"` +} +type dnsReplyWrapper struct { + Reply *dnsMessageAlias `json:"dns_reply"` +} + +func (d *DNSMessage) MarshalJSON() ([]byte, error) { + if d.Flags.QR == 0 { + return json.Marshal(&dnsQueryWrapper{Query: (*dnsMessageAlias)(d)}) + } + return json.Marshal(&dnsReplyWrapper{Reply: (*dnsMessageAlias)(d)}) } type RecordClass struct { - Name string - Val uint16 + Name string `json:"name"` + Val uint16 `json:"val"` } func (c *RecordClass) String() string { @@ -351,8 +366,8 @@ func className(cls uint16) string { } type RecordType struct { - Name string - Val uint16 + Name string `json:"name"` + Val uint16 `json:"val"` } func (rt *RecordType) String() string { @@ -391,12 +406,12 @@ func typeName(typ uint16) string { } type ResourceRecord struct { - Name string // Name of the node to which this record pertains. - Type *RecordType // Type of RR in numeric form. - Class *RecordClass // Class code. - TTL uint32 // Count of seconds that the RR stays valid. - RDLength uint16 // Length of RData field (specified in octets). - RData fmt.Stringer // Additional RR-specific data. + Name string `json:"name"` // Name of the node to which this record pertains. + Type *RecordType `json:"record-type"` // Type of RR in numeric form. + Class *RecordClass `json:"record-class"` // Class code. + TTL uint32 `json:"ttl"` // Count of seconds that the RR stays valid. + RDLength uint16 `json:"rdata-length"` // Length of RData field (specified in octets). + RData fmt.Stringer `json:"rdata"` // Additional RR-specific data. } func (rt *ResourceRecord) String() string { @@ -453,9 +468,9 @@ func (rt *ResourceRecord) Summary() string { } type QueryEntry struct { - Name string // Name of the node to which this record pertains. - Type *RecordType // Type of RR in numeric form. - Class *RecordClass // Class code. + Name string `json:"name"` // Name of the node to which this record pertains. + Type *RecordType `json:"record-type"` // Type of RR in numeric form. + Class *RecordClass `json:"record-class"` // Class code. } func (qe *QueryEntry) String() string { @@ -467,7 +482,7 @@ func (qe *QueryEntry) String() string { } type RDataA struct { - Address netip.Addr + Address netip.Addr `json:"address"` } func (d *RDataA) String() string { @@ -475,7 +490,7 @@ func (d *RDataA) String() string { } type RDataNS struct { - NsdName string + NsdName string `json:"ns"` } func (d *RDataNS) String() string { @@ -483,7 +498,7 @@ func (d *RDataNS) String() string { } type RDataCNAME struct { - CName string + CName string `json:"cname"` } func (d *RDataCNAME) String() string { @@ -491,13 +506,13 @@ func (d *RDataCNAME) String() string { } type RDataSOA struct { - PrimaryNS string - RespAuthorityMailbox string - SerialNumber uint32 - RefreshInterval uint32 - RetryInterval uint32 - ExpireLimit uint32 - MinimumTTL uint32 + PrimaryNS string `json:"primary-nameserver"` + RespAuthorityMailbox string `json:"responsible-authority-mailbox"` + SerialNumber uint32 `json:"serial-number"` + RefreshInterval uint32 `json:"refresh-interval"` + RetryInterval uint32 `json:"retry-interval"` + ExpireLimit uint32 `json:"expire-limit"` + MinimumTTL uint32 `json:"minimum-ttl"` } func (d *RDataSOA) String() string { @@ -518,8 +533,8 @@ func (d *RDataSOA) String() string { } type RDataMX struct { - Preference uint16 - Exchange string + Preference uint16 `json:"preference"` + Exchange string `json:"exchange"` } func (d *RDataMX) String() string { @@ -527,7 +542,7 @@ func (d *RDataMX) String() string { } type RDataTXT struct { - TxtData string + TxtData string `json:"txt-data"` } func (d *RDataTXT) String() string { @@ -535,7 +550,7 @@ func (d *RDataTXT) String() string { } type RDataAAAA struct { - Address netip.Addr + Address netip.Addr `json:"address"` } func (d *RDataAAAA) String() string { @@ -543,11 +558,11 @@ func (d *RDataAAAA) String() string { } type RDataOPT struct { - UDPPayloadSize uint16 - HigherBitsExtRCode uint8 - EDNSVer uint8 - Z uint16 - DataLen uint16 + UDPPayloadSize uint16 `json:"udp-payload-size"` + HigherBitsExtRCode uint8 `json:"higer-bits-in-extended-rcode"` + EDNSVer uint8 `json:"edns0-version"` + Z uint16 `json:"z"` + DataLen uint16 `json:"data-length"` } func (d *RDataOPT) String() string { @@ -565,8 +580,8 @@ func (d *RDataOPT) String() string { } type SvcParamKey struct { - Val uint16 - Desc string + Val uint16 `json:"val"` + Desc string `json:"desc"` } // https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml @@ -608,9 +623,9 @@ func (spk *SvcParamKey) String() string { } type SvcParam struct { - Key *SvcParamKey - Length uint16 - Value []byte // TODO: add proper parsing + Key *SvcParamKey `json:"svc-param-key"` + Length uint16 `json:"svc-param-value-length"` + Value []byte `json:"svc-param-value"` // TODO: add proper parsing } func newSvcParam(data []byte) (*SvcParam, []byte, error) { @@ -639,10 +654,10 @@ func (sp *SvcParam) String() string { } type RDataHTTPS struct { - SvcPriority uint16 - Length int - TargetName string - SvcParams []*SvcParam + SvcPriority uint16 `json:"svc-priority"` + Length int `json:"length"` + TargetName string `json:"target-name"` + SvcParams []*SvcParam `json:"svc-params"` } func (d *RDataHTTPS) printSvcParams() string { @@ -668,7 +683,7 @@ func (d *RDataHTTPS) String() string { } type RDataUnknown struct { - Data string + Data string `json:"data"` } func (d *RDataUnknown) String() string { @@ -681,7 +696,7 @@ func (d *RDataUnknown) String() string { func extractDomain(payload, tail []byte) (string, []byte, error) { // see https://brunoscheufler.com/blog/2024-05-12-building-a-dns-message-parser#domain-names var domainName string - for { + for len(tail) > 0 { blen := tail[0] if blen>>6 == 0b11 { if len(tail) < 2 { diff --git a/layers/ethernet.go b/layers/ethernet.go index 5e31870..fe9b147 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -98,6 +98,9 @@ func (ef *EthernetFrame) UnmarshalBinary(data []byte) error { ef.SrcMAC = net.HardwareAddr(buf[6:12]) et := EtherType(binary.BigEndian.Uint16(buf[12:14])) etdesc := ethertypedesc(et) + if etdesc == "Unknown" { + return fmt.Errorf("failed determining Ethernet type") + } ef.EtherType = &EthernetType{Val: et, Desc: etdesc} ef.Payload = buf[headerSizeEthernet:] ef.DstVendor = oui.VendorWithMAC(ef.DstMAC) @@ -131,7 +134,7 @@ func ethertypedesc(et EtherType) string { case EtherTypeIPv6: etdesc = "IPv6" default: - etdesc = "" + etdesc = "Unknown" } return etdesc } diff --git a/layers/http.go b/layers/http.go index 535394c..1141abd 100644 --- a/layers/http.go +++ b/layers/http.go @@ -64,7 +64,7 @@ func (h *HTTPMessage) Parse(data []byte) error { if !bytes.Contains(buf, protohttp10) && !bytes.Contains(buf, protohttp11) { h.Request = nil h.Response = nil - return nil + return fmt.Errorf("message does not contain protocol") } reader := bufio.NewReader(bytes.NewReader(buf)) if bytes.HasPrefix(buf, protohttp11) || bytes.HasPrefix(buf, protohttp10) { @@ -90,7 +90,7 @@ func (h *HTTPMessage) NextLayer() Layer { return nil } func (h *HTTPMessage) Name() LayerName { return LayerHTTP } type HTTPRequestWrapper struct { - Request HTTPRequest `json:"http_request"` + Request *HTTPRequest `json:"http_request"` } type HTTPRequest struct { @@ -103,7 +103,7 @@ type HTTPRequest struct { } type HTTPResponseWrapper struct { - Response HTTPResponse `json:"http_response"` + Response *HTTPResponse `json:"http_response"` } type HTTPResponse struct { @@ -115,7 +115,7 @@ type HTTPResponse struct { func (h *HTTPMessage) MarshalJSON() ([]byte, error) { if h.Request != nil { - return json.Marshal(&HTTPRequestWrapper{Request: HTTPRequest{ + return json.Marshal(&HTTPRequestWrapper{Request: &HTTPRequest{ Host: h.Request.Host, URI: h.Request.RequestURI, Method: h.Request.Method, @@ -124,7 +124,7 @@ func (h *HTTPMessage) MarshalJSON() ([]byte, error) { Header: h.Request.Header, }}) } else if h.Response != nil { - return json.Marshal(&HTTPResponseWrapper{Response: HTTPResponse{ + return json.Marshal(&HTTPResponseWrapper{Response: &HTTPResponse{ Proto: h.Response.Proto, Status: h.Response.Status, ContentLength: int(h.Response.ContentLength), diff --git a/layers/ipv4.go b/layers/ipv4.go index 94c3b29..861d31b 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -83,7 +83,11 @@ func NewIPv4Packet(srcIP, dstIP netip.Addr, proto IPProto, payload []byte) (*IPv DstIP: dstIP, Payload: payload, } - ipPacket.HeaderChecksum, _ = CalculateIPv4Checksum(ipPacket.ToBytes()) + headerChecksum, err := CalculateIPv4Checksum(ipPacket.ToBytes()) + if err != nil { + return nil, fmt.Errorf("failed calculating checksum") + } + ipPacket.HeaderChecksum = headerChecksum return ipPacket, nil } @@ -159,6 +163,9 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { buf = append(buf, data...) versionIHL := buf[0] p.Version = versionIHL >> 4 + if p.Version != 4 { + return fmt.Errorf("unknown version") + } p.IHL = versionIHL & 15 dscpECN := buf[1] p.DSCP = dscpECN >> 2 @@ -172,7 +179,11 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { p.FragmentOffset = flagsOffset & (1<<13 - 1) p.TTL = buf[8] proto := IPProto(buf[9]) - p.Protocol = &IPv4Proto{Val: proto, Desc: protodesc(proto)} + protodesc := protodesc(proto) + if protodesc == "Unknown" { + return fmt.Errorf("unknown protocol") + } + p.Protocol = &IPv4Proto{Val: proto, Desc: protodesc} p.HeaderChecksum = binary.BigEndian.Uint16(buf[headerChecksumOffsetIPv4:12]) var ok bool p.SrcIP, ok = netip.AddrFromSlice(buf[12:16]) @@ -279,6 +290,10 @@ func (p *IPv4Packet) PseudoHeader() *IPv4PseudoHeader { return &IPv4PseudoHeader{SrcIP: p.SrcIP, DstIP: p.DstIP, Protocol: p.Protocol, TotalLength: uint16(len(p.Payload))} } +func (p *IPv4Packet) SetPayload(payload []byte) { + p.Payload = payload +} + type IPv4PseudoHeader struct { SrcIP netip.Addr DstIP netip.Addr diff --git a/layers/ipv6.go b/layers/ipv6.go index cd54ee4..a598e0b 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -84,11 +84,20 @@ func (p *IPv6Packet) Parse(data []byte) error { buf = append(buf, data...) versionTrafficFlow := binary.BigEndian.Uint32(buf[0:4]) p.Version = uint8(versionTrafficFlow >> 28) + if p.Version != 6 { + return fmt.Errorf("unknown version") + } p.TrafficClass = newTrafficiClass(uint8((versionTrafficFlow >> 20) & 0xFF)) + if p.TrafficClass.DSCPDesc == "Unknown" { + return fmt.Errorf("unknown DSCP") + } p.FlowLabel = versionTrafficFlow & (1<<20 - 1) p.PayloadLength = binary.BigEndian.Uint16(buf[4:6]) p.NextHeader = buf[6] p.NextHeaderDesc = p.nextHeader() + if p.NextHeaderDesc == "Unknown" { + return fmt.Errorf("unknown next header") + } p.HopLimit = buf[7] var ok bool p.SrcIP, ok = netip.AddrFromSlice(buf[8:24]) @@ -161,7 +170,7 @@ func (p *IPv6Packet) nextHeader() string { case 140: header = "Shim6 Protocol" default: - header = "" + header = "Unknown" } return header } diff --git a/layers/layers.go b/layers/layers.go index d4f6dad..8943af7 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -206,10 +206,7 @@ func ParseNextLayer(data []byte, src, dst *uint16) Layer { return next } } - if next = parseNextLayerFromBytes(buf); next != nil { - return next - } - return parseNextLayerFallback(buf) + return parseNextLayerFromBytes(buf) } func GetLayer(layer LayerName) Layer { diff --git a/layers/tcp.go b/layers/tcp.go index 374cb23..0df3a58 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -128,7 +128,7 @@ func (t *TCPSegment) Parse(data []byte) error { t.Checksum = binary.BigEndian.Uint16(buf[16:18]) t.UrgentPointer = binary.BigEndian.Uint16(buf[18:headerSizeTCP]) offset := t.DataOffset << 2 - if len(buf) < int(offset) { + if len(buf) < int(offset) || int(offset) < headerSizeTCP { return ErrSliceBounds } t.Options = buf[headerSizeTCP:offset] diff --git a/layers/tls.go b/layers/tls.go index 8e74bcd..cf110c0 100644 --- a/layers/tls.go +++ b/layers/tls.go @@ -86,11 +86,17 @@ type ServerName struct { } func (sn *ServerName) Parse(data []byte) error { + if len(data) < 9 { + return ErrSliceBounds + } sn.Type = binary.BigEndian.Uint16(data[0:2]) sn.Length = binary.BigEndian.Uint16(data[2:4]) sn.SNListLength = binary.BigEndian.Uint16(data[4:6]) sn.SNType = data[6] sn.SNNameLength = binary.BigEndian.Uint16(data[7:9]) + if int(9+sn.SNNameLength) > len(data) { + return ErrSliceBounds + } sn.SNName = string(data[9 : 9+sn.SNNameLength]) return nil } @@ -545,10 +551,13 @@ func (t *TLSMessage) Parse(data []byte) error { if len(buf) < headerSizeTLS { return ErrTLSTooShort } - for len(buf) > 0 { + for i := 0; len(buf) > 0; i++ { ctype := buf[0] ctdesc := ctdesc(ctype) if ctdesc == "Unknown" { + if i == 0 { + return fmt.Errorf("unknown content type") + } break } if len(buf) < 3 { @@ -557,6 +566,9 @@ func (t *TLSMessage) Parse(data []byte) error { ver := binary.BigEndian.Uint16(buf[1:3]) verdesc := verdesc(ver) if verdesc == "Unknown" { + if i == 0 { + return fmt.Errorf("unknown version") + } break } if len(buf) < headerSizeTLS { diff --git a/layers/udp.go b/layers/udp.go index d2ad1a5..bca6ad5 100644 --- a/layers/udp.go +++ b/layers/udp.go @@ -19,15 +19,8 @@ type UDPSegment struct { Payload []byte } -func NewUDPSegment(srcPort, dstPort uint16, payload []byte, pseudo *[]byte) (*UDPSegment, error) { +func NewUDPSegment(srcPort, dstPort uint16, payload []byte) (*UDPSegment, error) { udp := &UDPSegment{SrcPort: srcPort, DstPort: dstPort, UDPLength: uint16(headerSizeUDP + len(payload)), Payload: payload} - if pseudo != nil { - var err error - udp.Checksum, err = CalculateUDPChecksum(append(*pseudo, udp.ToBytes()...)) - if err != nil { - return nil, err - } - } return udp, nil } @@ -92,6 +85,15 @@ func (u *UDPSegment) NextLayer() Layer { func (u *UDPSegment) Name() LayerName { return LayerUDP } +func (u *UDPSegment) SetChecksum(pseudo []byte) error { + checksum, err := CalculateUDPChecksum(append(pseudo, u.ToBytes()...)) + if err != nil { + return err + } + u.Checksum = checksum + return nil +} + func CalculateUDPChecksum(data []byte) (uint16, error) { var sum uint16 udpLength := len(data) diff --git a/mshark.go b/mshark.go index 5b5af40..d54879b 100644 --- a/mshark.go +++ b/mshark.go @@ -11,14 +11,10 @@ import ( "strings" "time" - "github.com/mdlayher/packet" - "github.com/packetcap/go-pcap/filter" "github.com/shadowy-pycoder/mshark/layers" - "golang.org/x/net/bpf" + "github.com/shadowy-pycoder/mshark/network" ) -const unixEthPAll int = 0x03 - var colorMap = map[int]string{ // TODO (shadowy-pycoder): add colors from shadowy-pycoder/colors 0: "\033[37m", 1: "\033[36m", @@ -144,39 +140,15 @@ func (mw *Writer) WriteHeader(c *Config) error { // OpenLive opens a live capture based on the given configuration and writes // all captured packets to the given PacketWriters. func OpenLive(conf *Config, pw ...PacketWriter) error { - packetcfg := packet.Config{} - - // setting up filter - if conf.Expr != "" { - e := filter.NewExpression(conf.Expr) - f := e.Compile() - instructions, err := f.Compile() - if err != nil { - return fmt.Errorf("failed to compile filter into instructions: %v", err) - } - raw, err := bpf.Assemble(instructions) - if err != nil { - return fmt.Errorf("bpf assembly failed: %v", err) - } - packetcfg.Filter = raw - } - - // opening connection - c, err := packet.Listen(conf.Device, packet.Raw, unixEthPAll, &packetcfg) - if err != nil { - if errors.Is(err, os.ErrPermission) { - return fmt.Errorf("permission denied (try setting CAP_NET_RAW capability): %v", err) - } - return fmt.Errorf("failed to listen: %v", err) - } - + lc := &network.ListenConfig{Device: conf.Device, FilterExpr: conf.Expr} // setting promisc mode if conf.Device.Name != "any" { - if err := c.SetPromiscuous(conf.Promisc); err != nil { - return fmt.Errorf("unable to set promiscuous mode: %v", err) - } + lc.Promiscuous = &conf.Promisc + } + c, err := network.ListenPacket(lc) + if err != nil { + return err } - // timeout if conf.Timeout > 0 { if err := c.SetDeadline(time.Now().Add(conf.Timeout)); err != nil { diff --git a/network/network.go b/network/network.go index 6213112..2c5e6e7 100644 --- a/network/network.go +++ b/network/network.go @@ -3,6 +3,7 @@ package network import ( "bufio" + "errors" "fmt" "net" "net/netip" @@ -10,13 +11,69 @@ import ( "os/exec" "strings" "text/tabwriter" + + "github.com/mdlayher/packet" + "github.com/packetcap/go-pcap/filter" + "golang.org/x/net/bpf" ) +const ETH_P_ALL int = 0x03 + var ( BroadcastMAC = net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} LoopbackMAC = net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} ) +type ListenConfig struct { + Device *net.Interface // network interface which to bind to, if not specified default interface is used + Protocol int // network protocol, defaults to ETH_P_ALL + Promiscuous *bool // enable or disable promiscuous mode + FilterExpr string // packet filter expression like in tcpdump +} + +func ListenPacket(conf *ListenConfig) (*packet.Conn, error) { + packetcfg := packet.Config{} + // setting up filter + if conf.FilterExpr != "" { + e := filter.NewExpression(conf.FilterExpr) + f := e.Compile() + instructions, err := f.Compile() + if err != nil { + return nil, fmt.Errorf("failed to compile filter into instructions: %v", err) + } + raw, err := bpf.Assemble(instructions) + if err != nil { + return nil, fmt.Errorf("bpf assembly failed: %v", err) + } + packetcfg.Filter = raw + } + if conf.Device == nil { + var err error + conf.Device, err = GetDefaultInterface() + if err != nil { + return nil, err + } + } + if conf.Protocol == 0 { + conf.Protocol = ETH_P_ALL + } + // opening connection + c, err := packet.Listen(conf.Device, packet.Raw, conf.Protocol, &packetcfg) + if err != nil { + if errors.Is(err, os.ErrPermission) { + return nil, fmt.Errorf("permission denied (try setting CAP_NET_RAW capability): %v", err) + } + return nil, fmt.Errorf("failed to listen: %v", err) + } + // setting promisc mode + if conf.Promiscuous != nil { + if err := c.SetPromiscuous(*conf.Promiscuous); err != nil { + return nil, fmt.Errorf("unable to set promiscuous mode: %v", err) + } + } + return c, nil +} + // InterfaceByName returns the interface specified by name. func InterfaceByName(name string) (*net.Interface, error) { var ( From 0d51cf820db536f8bf61d7c57dbbfb221384f408 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Mon, 11 Aug 2025 11:35:00 +0300 Subject: [PATCH 13/17] added methods to arpspoofer --- arpspoof/arpspoof.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/arpspoof/arpspoof.go b/arpspoof/arpspoof.go index 8c14c46..2bba3bd 100644 --- a/arpspoof/arpspoof.go +++ b/arpspoof/arpspoof.go @@ -53,7 +53,7 @@ type ARPSpoofConfig struct { // NewARPSpoofConfig creates ARPSpoofConfig from a list of options separated by semicolon and logger. // -// Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true;interface eth0;gateway 192.168.1.1"`. +// Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true;interface eth0;gateway 192.168.1.1". // All fields in configuration string are optional. func NewARPSpoofConfig(s string, logger *zerolog.Logger) (*ARPSpoofConfig, error) { asc := &ARPSpoofConfig{Logger: logger} @@ -173,6 +173,30 @@ type ARPSpoofer struct { p *packet.Conn } +func (ar *ARPSpoofer) Interface() *net.Interface { + return ar.iface +} + +func (ar *ARPSpoofer) GatewayIP() netip.Addr { + return ar.gwIP +} + +func (ar *ARPSpoofer) GatewayMAC() net.HardwareAddr { + return ar.gwMAC +} + +func (ar *ARPSpoofer) HostIP() netip.Addr { + return ar.hostIP +} + +func (ar *ARPSpoofer) HostMAC() net.HardwareAddr { + return ar.hostMAC +} + +func (ar *ARPSpoofer) ARPTable() *ARPTable { + return ar.arpTable +} + func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { arpspoofer := &ARPSpoofer{} // determining interface From 9261160ba0d6f8cf8ecbc47563b0787369dd4cac Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Tue, 12 Aug 2025 10:03:39 +0300 Subject: [PATCH 14/17] added more checks to eliminate false positives while parsing packets --- layers/ftp.go | 14 ++++++++------ layers/icmp.go | 3 +++ layers/ipv4.go | 6 ++++++ layers/ipv6.go | 7 +++++-- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/layers/ftp.go b/layers/ftp.go index 63d6592..df7cd90 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -26,20 +26,22 @@ func (f *FTPMessage) Parse(data []byte) error { if !checkFTP(buf) { return fmt.Errorf("malformed ftp message") } - f.summary = nil - f.data = nil sp := bytes.Split(buf, crlf) lsp := len(sp) switch { case lsp > 2: - f.summary = bytes.Join(sp[:2], bspace) + f.summary = bytes.TrimSpace(bytes.Join(sp[:2], bspace)) sp[0] = joinBytes(dash, sp[0]) - f.data = bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf) + f.data = bytes.TrimSpace(bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf)) case lsp > 1: - f.summary = sp[0] + f.summary = bytes.TrimSpace(sp[0]) sp[0] = joinBytes(dash, sp[0]) - f.data = bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf) + f.data = bytes.TrimSpace(bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf)) default: + return fmt.Errorf("failed parsing FTP message") + } + if len(f.summary) == 0 || len(f.data) == 0 { + return fmt.Errorf("failed parsing FTP message") } return nil } diff --git a/layers/icmp.go b/layers/icmp.go index 495b8ea..83c0001 100644 --- a/layers/icmp.go +++ b/layers/icmp.go @@ -65,6 +65,9 @@ func (i *ICMPSegment) Parse(data []byte) error { return fmt.Errorf("minimum payload length for ICMP with type %d is %d bytes", i.Type, pLen) } i.TypeDesc, i.CodeDesc = i.typecode() + if i.TypeDesc == "Unknown" || i.CodeDesc == "Unknown" { + return fmt.Errorf("failed determining type or code") + } return nil } diff --git a/layers/ipv4.go b/layers/ipv4.go index 861d31b..4dffbf9 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -170,8 +170,14 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { dscpECN := buf[1] p.DSCP = dscpECN >> 2 p.DSCPDesc = dscpdesc(p.DSCP) + if p.DSCPDesc == "Unknown" { + return fmt.Errorf("unknown DSCP") + } p.ECN = dscpECN & 3 p.TotalLength = binary.BigEndian.Uint16(buf[2:4]) + if int(p.TotalLength) != len(buf) { + return fmt.Errorf("total length is not equal to actual packet size") + } p.Identification = binary.BigEndian.Uint16(buf[4:6]) flagsOffset := binary.BigEndian.Uint16(buf[6:8]) flags := uint8(flagsOffset >> 13) diff --git a/layers/ipv6.go b/layers/ipv6.go index a598e0b..4974693 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -15,7 +15,7 @@ type TrafficClass struct { ECN uint8 } -func newTrafficiClass(tc uint8) *TrafficClass { +func newTrafficClass(tc uint8) *TrafficClass { dscpbin := tc >> 2 return &TrafficClass{ Raw: tc, @@ -87,7 +87,7 @@ func (p *IPv6Packet) Parse(data []byte) error { if p.Version != 6 { return fmt.Errorf("unknown version") } - p.TrafficClass = newTrafficiClass(uint8((versionTrafficFlow >> 20) & 0xFF)) + p.TrafficClass = newTrafficClass(uint8((versionTrafficFlow >> 20) & 0xFF)) if p.TrafficClass.DSCPDesc == "Unknown" { return fmt.Errorf("unknown DSCP") } @@ -109,6 +109,9 @@ func (p *IPv6Packet) Parse(data []byte) error { return fmt.Errorf("malformed IPv6 address") } p.Payload = buf[headerSizeIPv6:] + if p.PayloadLength != 0 && int(p.PayloadLength) != len(p.Payload) { + return fmt.Errorf("payload length filed is not equal to actual payload size") + } return nil } From f0d48fc437b25630af7b8ac310ac782cf8f5cb99 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Tue, 12 Aug 2025 10:32:01 +0300 Subject: [PATCH 15/17] ipv4 identification randomization, changed ttl and flags when creating war ipv4 packets --- layers/ipv4.go | 19 ++++++++++--------- layers/layers.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/layers/ipv4.go b/layers/ipv4.go index 4dffbf9..a64d387 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -73,15 +73,16 @@ func NewIPv4Packet(srcIP, dstIP netip.Addr, proto IPProto, payload []byte) (*IPv return nil, fmt.Errorf("malformed IPv4 address") } ipPacket := &IPv4Packet{ - Version: 4, - IHL: 5, - TotalLength: uint16(headerSizeIPv4 + len(payload)), - Flags: NewIPv4Flags(0), - TTL: 64, - Protocol: &IPv4Proto{Val: proto, Desc: protodesc(proto)}, - SrcIP: srcIP, - DstIP: dstIP, - Payload: payload, + Version: 4, + IHL: 5, + TotalLength: uint16(headerSizeIPv4 + len(payload)), + Identification: MustGenerateRandomUint16NE(), + Flags: NewIPv4Flags(2), + TTL: 128, + Protocol: &IPv4Proto{Val: proto, Desc: protodesc(proto)}, + SrcIP: srcIP, + DstIP: dstIP, + Payload: payload, } headerChecksum, err := CalculateIPv4Checksum(ipPacket.ToBytes()) if err != nil { diff --git a/layers/layers.go b/layers/layers.go index 8943af7..b50948a 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -3,9 +3,12 @@ package layers import ( "bytes" + "crypto/rand" "encoding/binary" "fmt" "unsafe" + + "github.com/shadowy-pycoder/mshark/native" ) const maxLenSummary = 110 @@ -273,3 +276,45 @@ func isDigit(b byte) bool { func isUpper(b byte) bool { return b >= 'A' && b <= 'Z' } + +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} + +func GenerateRandomUint16LE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(b), nil +} + +func GenerateRandomUint16BE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint16(b), nil +} + +func GenerateRandomUint16NE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return native.Endian.Uint16(b), nil +} + +func MustGenerateRandomUint16NE() uint16 { + rn, err := GenerateRandomUint16NE() + if err != nil { + panic(err) + } + return rn +} From 1679c0ef240e6ce13674a6622e0567f11fd16b05 Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Tue, 12 Aug 2025 10:33:01 +0300 Subject: [PATCH 16/17] bumped to 0.0.14 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 8db832f..168cd96 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.13" +const Version string = "mshark v0.0.14" From 8cf5fcd24b3e02364df9cd7ba339aae133b34a2d Mon Sep 17 00:00:00 2001 From: shadowy-pycoder <35629483+shadowy-pycoder@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:59:49 +0300 Subject: [PATCH 17/17] updated network package to work on android --- arpspoof/arpspoof.go | 14 +++++++++++--- network/network.go | 31 +++++++++++++++++++++++++++++++ version.go | 2 +- 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/arpspoof/arpspoof.go b/arpspoof/arpspoof.go index 2bba3bd..58bce7a 100644 --- a/arpspoof/arpspoof.go +++ b/arpspoof/arpspoof.go @@ -200,9 +200,14 @@ func (ar *ARPSpoofer) ARPTable() *ARPTable { func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { arpspoofer := &ARPSpoofer{} // determining interface - iface, err := network.GetDefaultInterface() + var iface *net.Interface + var err error + iface, err = network.GetDefaultInterface() if err != nil { - return nil, err + iface, err = network.GetDefaultInterfaceFromRoute() + if err != nil { + return nil, err + } } if conf.Interface != "" { arpspoofer.iface, err = net.InterfaceByName(conf.Interface) @@ -236,7 +241,10 @@ func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { } else { gwIP, err = network.GetDefaultGatewayIPv4() if err != nil { - return nil, fmt.Errorf("failed fetching gateway ip: %w", err) + gwIP, err = network.GetDefaultGatewayIPv4FromRoute() + if err != nil { + return nil, fmt.Errorf("failed fetching gateway ip: %w", err) + } } } arpspoofer.gwIP = gwIP diff --git a/network/network.go b/network/network.go index 2c5e6e7..bdef9a4 100644 --- a/network/network.go +++ b/network/network.go @@ -139,6 +139,21 @@ func GetDefaultInterface() (*net.Interface, error) { return net.InterfaceByName(defaultInterface) } +func GetDefaultInterfaceFromRoute() (*net.Interface, error) { + cmd := exec.Command("sh", "-c", `ip -4 route get 8.8.8.8 | tr -d '\n'`) + routeRaw, err := cmd.Output() + if err != nil { + return nil, err + } + routeFields := strings.Fields(string(routeRaw)) + for i, f := range routeFields { + if f == "dev" && i+1 < len(routeFields) && routeFields[i+1] != "tun" { + return net.InterfaceByName(routeFields[i+1]) + } + } + return nil, fmt.Errorf("failed getting default interface from route") +} + func GetDefaultGatewayIPv4() (netip.Addr, error) { cmd := exec.Command("sh", "-c", `ip -4 route show 0.0.0.0/0 | awk '{print $3 " " $5}'`) ipdevRaw, err := cmd.Output() @@ -167,6 +182,22 @@ func GetDefaultGatewayIPv4() (netip.Addr, error) { return netip.Addr{}, fmt.Errorf("gateway IPv4 not found ") } +func GetDefaultGatewayIPv4FromRoute() (netip.Addr, error) { + cmd := exec.Command("sh", "-c", `ip -4 route get 8.8.8.8 | awk '{print $3}' | tr -d '\n'`) + ipstrRaw, err := cmd.Output() + if err != nil { + return netip.Addr{}, err + } + ip, err := netip.ParseAddr(string(ipstrRaw)) + if err != nil { + return netip.Addr{}, err + } + if !ip.IsValid() || !ip.Is4() { + return netip.Addr{}, fmt.Errorf("failed getting default gateway from route") + } + return ip, nil +} + func GetGatewayIPv4FromInterface(iface string) (netip.Addr, error) { cmd := exec.Command("sh", "-c", fmt.Sprintf("ip -4 route show dev %s", iface)) routes, err := cmd.Output() diff --git a/version.go b/version.go index 168cd96..3604c1f 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.14" +const Version string = "mshark v0.0.15"