diff --git a/arpspoof/arpspoof.go b/arpspoof/arpspoof.go index 7f53b8e..58bce7a 100644 --- a/arpspoof/arpspoof.go +++ b/arpspoof/arpspoof.go @@ -12,6 +12,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/malfunkt/iprange" @@ -23,7 +24,6 @@ import ( ) const ( - protocolARP = 0x0806 unixEthPAll = 0x03 ) @@ -32,6 +32,9 @@ var ( probeTargetsInterval = 60 * time.Second refreshARPTableInterval = 15 * time.Second arpSpoofTargetsInterval = 1 * time.Second + errARPSpoofConfig = fmt.Errorf( + `failed parsing arp options. Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true"`, + ) ) type Packet struct { @@ -48,6 +51,45 @@ type ARPSpoofConfig struct { Debug bool } +// NewARPSpoofConfig creates ARPSpoofConfig from a list of options separated by semicolon and logger. +// +// Example: "targets 10.0.0.1,10.0.0.5-10,192.168.1.*,192.168.10.0/24;fullduplex false;debug true;interface eth0;gateway 192.168.1.1". +// All fields in configuration string are optional. +func NewARPSpoofConfig(s string, logger *zerolog.Logger) (*ARPSpoofConfig, error) { + asc := &ARPSpoofConfig{Logger: logger} + for opt := range strings.SplitSeq(strings.ToLower(s), ";") { + keyval := strings.SplitN(strings.Trim(opt, " "), " ", 2) + if len(keyval) < 2 { + return nil, errARPSpoofConfig + } + key := keyval[0] + val := keyval[1] + switch key { + case "targets": + asc.Targets = val + case "interface": + asc.Interface = val + case "gateway": + gateway, err := netip.ParseAddr(val) + if err != nil { + return nil, err + } + asc.Gateway = &gateway + case "fullduplex": + if val == "true" { + asc.FullDuplex = true + } + case "debug": + if val == "true" { + asc.Debug = true + } + default: + return nil, errARPSpoofConfig + } + } + return asc, nil +} + type ARPTable struct { sync.RWMutex Ifname string @@ -115,27 +157,57 @@ func (at *ARPTable) Refresh() error { } type ARPSpoofer struct { - targets []netip.Addr - gwIP netip.Addr - gwMAC net.HardwareAddr - iface *net.Interface - hostIP netip.Addr - hostMAC net.HardwareAddr - fullduplex bool - arpTable *ARPTable - packets chan *Packet - logger *zerolog.Logger - quit chan bool - wg sync.WaitGroup - p *packet.Conn + targets []netip.Addr + gwIP netip.Addr + gwMAC net.HardwareAddr + iface *net.Interface + hostIP netip.Addr + hostMAC net.HardwareAddr + fullduplex bool + startingFlag atomic.Bool + arpTable *ARPTable + packets chan *Packet + logger *zerolog.Logger + quit chan bool + wg sync.WaitGroup + p *packet.Conn +} + +func (ar *ARPSpoofer) Interface() *net.Interface { + return ar.iface +} + +func (ar *ARPSpoofer) GatewayIP() netip.Addr { + return ar.gwIP +} + +func (ar *ARPSpoofer) GatewayMAC() net.HardwareAddr { + return ar.gwMAC +} + +func (ar *ARPSpoofer) HostIP() netip.Addr { + return ar.hostIP +} + +func (ar *ARPSpoofer) HostMAC() net.HardwareAddr { + return ar.hostMAC +} + +func (ar *ARPSpoofer) ARPTable() *ARPTable { + return ar.arpTable } func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { arpspoofer := &ARPSpoofer{} // determining interface - iface, err := network.GetDefaultInterface() + var iface *net.Interface + var err error + iface, err = network.GetDefaultInterface() if err != nil { - return nil, err + iface, err = network.GetDefaultInterfaceFromRoute() + if err != nil { + return nil, err + } } if conf.Interface != "" { arpspoofer.iface, err = net.InterfaceByName(conf.Interface) @@ -169,7 +241,10 @@ func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { } else { gwIP, err = network.GetDefaultGatewayIPv4() if err != nil { - return nil, fmt.Errorf("failed fetching gateway ip: %w", err) + gwIP, err = network.GetDefaultGatewayIPv4FromRoute() + if err != nil { + return nil, fmt.Errorf("failed fetching gateway ip: %w", err) + } } } arpspoofer.gwIP = gwIP @@ -261,6 +336,7 @@ func NewARPSpoofer(conf *ARPSpoofConfig) (*ARPSpoofer, error) { } func (ar *ARPSpoofer) Start() { + ar.startingFlag.Store(true) ar.logger.Info().Msg("[arp spoofer] Started") go ar.handlePackets() ar.logger.Debug().Msgf("[arp spoofer] Probing %d targets", len(ar.targets)) @@ -271,6 +347,7 @@ func (ar *ARPSpoofer) Start() { go ar.probeTargets() go ar.refreshARPTable() ar.wg.Add(1) + ar.startingFlag.Store(false) for { select { case <-ar.quit: @@ -284,6 +361,9 @@ func (ar *ARPSpoofer) Start() { } func (ar *ARPSpoofer) Stop() error { + for ar.startingFlag.Load() { + time.Sleep(50 * time.Millisecond) + } var err error ar.logger.Info().Msg("[arp spoofer] Stopping...") close(ar.quit) diff --git a/cmd/marpspoof/main.go b/cmd/marpspoof/main.go index f0ed072..9624247 100644 --- a/cmd/marpspoof/main.go +++ b/cmd/marpspoof/main.go @@ -38,7 +38,7 @@ func root(args []string) error { flags.BoolVar(&conf.Debug, "d", false, "Enable debug logging") nocolor := flags.Bool("nocolor", false, "Disable colored output") flags.BoolFunc("I", "Display list of interfaces and exit.", func(flagValue string) error { - if err := network.DisplayInterfaces(); err != nil { + if err := network.DisplayInterfaces(false); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", app, err) os.Exit(2) } diff --git a/cmd/mshark/cli.go b/cmd/mshark/cli.go index ad989d0..09585e9 100644 --- a/cmd/mshark/cli.go +++ b/cmd/mshark/cli.go @@ -89,7 +89,7 @@ func root(args []string) error { packetBuffer := flags.Int("b", 8192, "The maximum size of packet queue.") flags.StringVar(&conf.Expr, "e", "", `BPF filter expression. Example: "ip proto tcp".`) flags.BoolFunc("D", "Display list of interfaces and exit.", func(flagValue string) error { - if err := network.DisplayInterfaces(); err != nil { + if err := network.DisplayInterfaces(true); err != nil { fmt.Fprintf(os.Stderr, "%s: %v\n", app, err) os.Exit(2) } diff --git a/layers/arp.go b/layers/arp.go index ec4e789..f69a07e 100644 --- a/layers/arp.go +++ b/layers/arp.go @@ -148,30 +148,41 @@ func (ap *ARPPacket) UnmarshalBinary(data []byte) error { if len(data) < headerSizeARP { return fmt.Errorf("minimum header size for ARP is %d bytes, got %d bytes", headerSizeARP, len(data)) } - ap.HardwareType = binary.BigEndian.Uint16(data[0:2]) - ap.ProtocolType = binary.BigEndian.Uint16(data[2:4]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + ap.HardwareType = binary.BigEndian.Uint16(buf[0:2]) + ap.ProtocolType = binary.BigEndian.Uint16(buf[2:4]) ap.ProtocolTypeDesc = ptypedesc(ap.ProtocolType) if ap.ProtocolTypeDesc == "Unknown" { return fmt.Errorf("unknown protocol type") } - ap.Hlen = data[4] - ap.Plen = data[5] - op := Operation(binary.BigEndian.Uint16(data[6:8])) + ap.Hlen = buf[4] + ap.Plen = buf[5] + op := Operation(binary.BigEndian.Uint16(buf[6:8])) opdesc := opdesc(op) if opdesc == "Unknown" { return fmt.Errorf("unknown operation") } ap.Op = &ARPOperation{Val: op, Desc: opdesc} hoffset := 8 + ap.Hlen - ap.SenderMAC = net.HardwareAddr(data[8:hoffset]) + ap.SenderMAC = net.HardwareAddr(buf[8:hoffset]) poffset := hoffset + ap.Plen + if int(poffset) > len(buf) { + return ErrSliceBounds + } var ok bool - ap.SenderIP, ok = netip.AddrFromSlice(data[hoffset:poffset]) + ap.SenderIP, ok = netip.AddrFromSlice(buf[hoffset:poffset]) if !ok { return fmt.Errorf("failed parsing sender IP address") } - ap.TargetMAC = net.HardwareAddr(data[poffset : poffset+ap.Hlen]) - ap.TargetIP, ok = netip.AddrFromSlice(data[poffset+ap.Hlen : poffset+ap.Hlen+ap.Plen]) + if int(poffset+ap.Hlen) > len(buf) { + return ErrSliceBounds + } + ap.TargetMAC = net.HardwareAddr(buf[poffset : poffset+ap.Hlen]) + if int(poffset+ap.Hlen+ap.Plen) > len(buf) { + return ErrSliceBounds + } + ap.TargetIP, ok = netip.AddrFromSlice(buf[poffset+ap.Hlen : poffset+ap.Hlen+ap.Plen]) if !ok { return fmt.Errorf("failed parsing target IP address") } @@ -185,7 +196,8 @@ func (ap *ARPPacket) Parse(data []byte) error { return ap.UnmarshalBinary(data) } -func (ap *ARPPacket) NextLayer() (layer string, payload []byte) { return } +func (ap *ARPPacket) NextLayer() Layer { return nil } +func (ap *ARPPacket) Name() LayerName { return LayerARP } func ptypedesc(pt uint16) string { var proto string diff --git a/layers/dns.go b/layers/dns.go index 2980f5a..2526629 100644 --- a/layers/dns.go +++ b/layers/dns.go @@ -2,6 +2,8 @@ package layers import ( "encoding/binary" + "encoding/hex" + "encoding/json" "fmt" "net/netip" "strings" @@ -10,20 +12,20 @@ import ( const headerSizeDNS = 12 type DNSFlags struct { - Raw uint16 - QR uint8 // Indicates if the message is a query (0) or a reply (1). - QRDesc string // Query (0) or Reply (1) - OPCode uint8 // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5 - OPCodeDesc string - AA uint8 // Authoritative Answer, in a response, indicates if the DNS server is authoritative for the queried hostname. - TC uint8 // TrunCation, indicates that this message was truncated due to excessive length. - RD uint8 // Recursion Desired, indicates if the client means a recursive query. - RA uint8 // Recursion Available, in a response, indicates if the replying DNS server supports recursion. - Z uint8 // Zero, reserved for future use. - AU uint8 // Indicates if answer/authority portion was authenticated by the server. - NA uint8 // Indicates if non-authenticated data is accepatable. - RCode uint8 // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 - RCodeDesc string + Raw uint16 `json:"raw"` + QR uint8 `json:"qr"` // Indicates if the message is a query (0) or a reply (1). + QRDesc string `json:"qrdesc"` // Query (0) or Reply (1) + OPCode uint8 `json:"opcode"` // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-5 + OPCodeDesc string `json:"opcodedesc"` + AA uint8 `json:"aa"` // Authoritative Answer, in a response, indicates if the DNS server is authoritative for the queried hostname. + TC uint8 `json:"tc"` // TrunCation, indicates that this message was truncated due to excessive length. + RD uint8 `json:"rd"` // Recursion Desired, indicates if the client means a recursive query. + RA uint8 `json:"ra"` // Recursion Available, in a response, indicates if the replying DNS server supports recursion. + Z uint8 `json:"z"` // Zero, reserved for future use. + AU uint8 `json:"au"` // Indicates if answer/authority portion was authenticated by the server. + NA uint8 `json:"na"` // Indicates if non-authenticated data is accepatable. + RCode uint8 `json:"rcode"` // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 + RCodeDesc string `json:"rcodedesc"` } func (df *DNSFlags) String() string { @@ -170,16 +172,16 @@ func rcdesc(rcode uint8) string { } type DNSMessage struct { - TransactionID uint16 // Used for matching response to queries. - Flags *DNSFlags // Flags specify the requested operation and a response code. - QDCount uint16 // Count of entries in the queries section. - ANCount uint16 // Count of entries in the answers section. - NSCount uint16 // Count of entries in the authority section. - ARCount uint16 // Count of entries in the additional section. - Questions []*QueryEntry - AnswerRRs []*ResourceRecord - AuthorityRRs []*ResourceRecord - AdditionalRRs []*ResourceRecord + TransactionID uint16 `json:"transaction-id"` // Used for matching response to queries. + Flags *DNSFlags `json:"flags,omitempty"` // Flags specify the requested operation and a response code. + QDCount uint16 `json:"questions-count"` // Count of entries in the queries section. + ANCount uint16 `json:"answer-rrs-count"` // Count of entries in the answers section. + NSCount uint16 `json:"authority-rrs-count"` // Count of entries in the authority section. + ARCount uint16 `json:"additional-rrs-count"` // Count of entries in the additional section. + Questions []*QueryEntry `json:"questions,omitempty"` + AnswerRRs []*ResourceRecord `json:"answers,omitempty"` + AuthorityRRs []*ResourceRecord `json:"authoritative-nameservers,omitempty"` + AdditionalRRs []*ResourceRecord `json:"additional-records,omitempty"` } func (d *DNSMessage) String() string { @@ -206,7 +208,7 @@ func (d *DNSMessage) String() string { func (d *DNSMessage) Summary() string { var sb strings.Builder - sb.WriteString(fmt.Sprintf("DNS Message: %s %s %#04x ", d.Flags.OPCodeDesc, d.Flags.QRDesc, d.TransactionID)) + sb.WriteString(fmt.Sprintf("DNS Message: %s (%s) %#04x ", d.Flags.OPCodeDesc, d.Flags.QRDesc, d.TransactionID)) for _, rec := range d.Questions { sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) if sb.Len() > maxLenSummary { @@ -214,19 +216,19 @@ func (d *DNSMessage) Summary() string { } } for _, rec := range d.AnswerRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } } for _, rec := range d.AuthorityRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } } for _, rec := range d.AdditionalRRs { - sb.WriteString(fmt.Sprintf("%s %s ", rec.Type.Name, rec.Name)) + sb.WriteString(rec.Summary()) if sb.Len() > maxLenSummary { goto result } @@ -236,39 +238,56 @@ result: return sb.String()[:maxLenSummary] + string(ellipsis) } -// Parse parses the given byte data into a DNSMessage struct. -func (d *DNSMessage) Parse(data []byte) error { +func (d *DNSMessage) UnmarshalBinary(data []byte) error { if len(data) < headerSizeDNS { return fmt.Errorf("minimum header size for DNS is %d bytes, got %d bytes", headerSizeDNS, len(data)) } - d.TransactionID = binary.BigEndian.Uint16(data[0:2]) - d.Flags = newDNSFlags(binary.BigEndian.Uint16(data[2:4])) - d.QDCount = binary.BigEndian.Uint16(data[4:6]) - d.ANCount = binary.BigEndian.Uint16(data[6:8]) - d.NSCount = binary.BigEndian.Uint16(data[8:10]) - d.ARCount = binary.BigEndian.Uint16(data[10:headerSizeDNS]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + d.TransactionID = binary.BigEndian.Uint16(buf[0:2]) + d.Flags = newDNSFlags(binary.BigEndian.Uint16(buf[2:4])) + d.QDCount = binary.BigEndian.Uint16(buf[4:6]) + d.ANCount = binary.BigEndian.Uint16(buf[6:8]) + d.NSCount = binary.BigEndian.Uint16(buf[8:10]) + d.ARCount = binary.BigEndian.Uint16(buf[10:headerSizeDNS]) var tail []byte - payload := data[headerSizeDNS:] - d.Questions = nil - d.AnswerRRs = nil - d.AuthorityRRs = nil - d.AdditionalRRs = nil + var err error + payload := buf[headerSizeDNS:] if d.QDCount > 0 { - d.Questions, tail = parseQueries(payload, payload, d.QDCount) + d.Questions, tail, err = parseQueries(payload, payload, d.QDCount) + if err != nil { + return fmt.Errorf("failed parsing queries: %v", err) + } } if d.ANCount > 0 { - d.AnswerRRs, tail = parseResourceRecords(payload, tail, d.ANCount) + d.AnswerRRs, tail, err = parseResourceRecords(payload, tail, d.ANCount) + if err != nil { + return fmt.Errorf("failed parsing answers: %v", err) + } } if d.NSCount > 0 { - d.AuthorityRRs, tail = parseResourceRecords(payload, tail, d.NSCount) + d.AuthorityRRs, tail, err = parseResourceRecords(payload, tail, d.NSCount) + if err != nil { + return fmt.Errorf("failed parsing authority records: %v", err) + } } if d.ARCount > 0 { - d.AdditionalRRs, _ = parseResourceRecords(payload, tail, d.ARCount) + d.AdditionalRRs, _, err = parseResourceRecords(payload, tail, d.ARCount) + if err != nil { + return fmt.Errorf("failed parsing additional records: %v", err) + } } return nil } -func (d *DNSMessage) NextLayer() (layer string, payload []byte) { return } +// Parse parses the given byte data into a DNSMessage struct. +func (d *DNSMessage) Parse(data []byte) error { + return d.UnmarshalBinary(data) +} + +func (d *DNSMessage) NextLayer() Layer { return nil } + +func (d *DNSMessage) Name() LayerName { return LayerDNS } func (d *DNSMessage) printRecords() string { var sb strings.Builder @@ -293,15 +312,31 @@ func (d *DNSMessage) printRecords() string { if d.ARCount > 0 { sb.WriteString("- Additional records:\n") for _, rec := range d.AdditionalRRs { - sb.WriteString(rec.String()) + sb.WriteString(strings.TrimSuffix(rec.String(), "\n")) } } return sb.String() } +type dnsMessageAlias DNSMessage + +type dnsQueryWrapper struct { + Query *dnsMessageAlias `json:"dns_query"` +} +type dnsReplyWrapper struct { + Reply *dnsMessageAlias `json:"dns_reply"` +} + +func (d *DNSMessage) MarshalJSON() ([]byte, error) { + if d.Flags.QR == 0 { + return json.Marshal(&dnsQueryWrapper{Query: (*dnsMessageAlias)(d)}) + } + return json.Marshal(&dnsReplyWrapper{Reply: (*dnsMessageAlias)(d)}) +} + type RecordClass struct { - Name string - Val uint16 + Name string `json:"name"` + Val uint16 `json:"val"` } func (c *RecordClass) String() string { @@ -331,8 +366,8 @@ func className(cls uint16) string { } type RecordType struct { - Name string - Val uint16 + Name string `json:"name"` + Val uint16 `json:"val"` } func (rt *RecordType) String() string { @@ -371,12 +406,12 @@ func typeName(typ uint16) string { } type ResourceRecord struct { - Name string // Name of the node to which this record pertains. - Type *RecordType // Type of RR in numeric form. - Class *RecordClass // Class code. - TTL uint32 // Count of seconds that the RR stays valid. - RDLength uint16 // Length of RData field (specified in octets). - RData fmt.Stringer // Additional RR-specific data. + Name string `json:"name"` // Name of the node to which this record pertains. + Type *RecordType `json:"record-type"` // Type of RR in numeric form. + Class *RecordClass `json:"record-class"` // Class code. + TTL uint32 `json:"ttl"` // Count of seconds that the RR stays valid. + RDLength uint16 `json:"rdata-length"` // Length of RData field (specified in octets). + RData fmt.Stringer `json:"rdata"` // Additional RR-specific data. } func (rt *ResourceRecord) String() string { @@ -410,10 +445,32 @@ func (rt *ResourceRecord) String() string { return record } +func (rt *ResourceRecord) Summary() string { + var summary string + switch rd := rt.RData.(type) { + case *RDataA: + case *RDataAAAA: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.Address) + case *RDataNS: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.NsdName) + case *RDataCNAME: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.CName) + case *RDataSOA: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.PrimaryNS) + case *RDataMX: + summary = fmt.Sprintf("%s %d %s ", rt.Type.Name, rd.Preference, rd.Exchange) + case *RDataTXT: + summary = fmt.Sprintf("%s %s ", rt.Type.Name, rd.TxtData) + default: + summary = fmt.Sprintf("%s ", rt.Type.Name) + } + return summary +} + type QueryEntry struct { - Name string // Name of the node to which this record pertains. - Type *RecordType // Type of RR in numeric form. - Class *RecordClass // Class code. + Name string `json:"name"` // Name of the node to which this record pertains. + Type *RecordType `json:"record-type"` // Type of RR in numeric form. + Class *RecordClass `json:"record-class"` // Class code. } func (qe *QueryEntry) String() string { @@ -425,7 +482,7 @@ func (qe *QueryEntry) String() string { } type RDataA struct { - Address netip.Addr + Address netip.Addr `json:"address"` } func (d *RDataA) String() string { @@ -433,7 +490,7 @@ func (d *RDataA) String() string { } type RDataNS struct { - NsdName string + NsdName string `json:"ns"` } func (d *RDataNS) String() string { @@ -441,7 +498,7 @@ func (d *RDataNS) String() string { } type RDataCNAME struct { - CName string + CName string `json:"cname"` } func (d *RDataCNAME) String() string { @@ -449,17 +506,17 @@ func (d *RDataCNAME) String() string { } type RDataSOA struct { - PrimaryNS string - RespAuthorityMailbox string - SerialNumber uint32 - RefreshInterval uint32 - RetryInterval uint32 - ExpireLimit uint32 - MinimumTTL uint32 + PrimaryNS string `json:"primary-nameserver"` + RespAuthorityMailbox string `json:"responsible-authority-mailbox"` + SerialNumber uint32 `json:"serial-number"` + RefreshInterval uint32 `json:"refresh-interval"` + RetryInterval uint32 `json:"retry-interval"` + ExpireLimit uint32 `json:"expire-limit"` + MinimumTTL uint32 `json:"minimum-ttl"` } func (d *RDataSOA) String() string { - return fmt.Sprintf(`Primary name server: %s + return fmt.Sprintf(`Primary name server: %s - Responsible authority's mailbox: %s - Serial number: %d - Refresh interval: %d @@ -476,8 +533,8 @@ func (d *RDataSOA) String() string { } type RDataMX struct { - Preference uint16 - Exchange string + Preference uint16 `json:"preference"` + Exchange string `json:"exchange"` } func (d *RDataMX) String() string { @@ -485,7 +542,7 @@ func (d *RDataMX) String() string { } type RDataTXT struct { - TxtData string + TxtData string `json:"txt-data"` } func (d *RDataTXT) String() string { @@ -493,24 +550,23 @@ func (d *RDataTXT) String() string { } type RDataAAAA struct { - Address netip.Addr + Address netip.Addr `json:"address"` } func (d *RDataAAAA) String() string { return fmt.Sprintf("Address: %s", d.Address) - } type RDataOPT struct { - UDPPayloadSize uint16 - HigherBitsExtRCode uint8 - EDNSVer uint8 - Z uint16 - DataLen uint16 + UDPPayloadSize uint16 `json:"udp-payload-size"` + HigherBitsExtRCode uint8 `json:"higer-bits-in-extended-rcode"` + EDNSVer uint8 `json:"edns0-version"` + Z uint16 `json:"z"` + DataLen uint16 `json:"data-length"` } func (d *RDataOPT) String() string { - return fmt.Sprintf(`UDP payload size: %d + return fmt.Sprintf(`UDP payload size: %d - Higher bits in extended RCODE: %#02x - EDNS0 version: %d - Z: %d @@ -523,16 +579,111 @@ func (d *RDataOPT) String() string { d.DataLen) } +type SvcParamKey struct { + Val uint16 `json:"val"` + Desc string `json:"desc"` +} + +// https://www.iana.org/assignments/dns-svcb/dns-svcb.xhtml +func svcparamkeydesc(key uint16) string { + var svcdesc string + switch key { + case 0: + svcdesc = "mandatory" + case 1: + svcdesc = "alpn" + case 2: + svcdesc = "no-default-alpn" + case 3: + svcdesc = "port" + case 4: + svcdesc = "ipv4hint" + case 5: + svcdesc = "ech" + case 6: + svcdesc = "ipv6hint" + case 7: + svcdesc = "dohpath" + case 8: + svcdesc = "ohttp" + case 9: + svcdesc = "tls-supported-groups" + default: + svcdesc = "Unknown" + } + return svcdesc +} + +func newSvcParamKey(key uint16) *SvcParamKey { + return &SvcParamKey{Val: key, Desc: svcparamkeydesc(key)} +} + +func (spk *SvcParamKey) String() string { + return fmt.Sprintf("%s (%d)", spk.Desc, spk.Val) +} + +type SvcParam struct { + Key *SvcParamKey `json:"svc-param-key"` + Length uint16 `json:"svc-param-value-length"` + Value []byte `json:"svc-param-value"` // TODO: add proper parsing +} + +func newSvcParam(data []byte) (*SvcParam, []byte, error) { + if len(data) < 4 { + return nil, nil, ErrSliceBounds + } + key := newSvcParamKey(binary.BigEndian.Uint16(data[0:2])) + length := binary.BigEndian.Uint16(data[2:4]) + offset := 4 + length + if offset > uint16(len(data)) { + return nil, nil, ErrSliceBounds + } + value := data[4:offset] + return &SvcParam{Key: key, Length: length, Value: value}, data[offset:], nil +} + +func (sp *SvcParam) String() string { + return fmt.Sprintf(` - SvcParamKey: %s + - SvcParamValue length: %d + - SvcParamValue: %s +`, + sp.Key, + sp.Length, + hex.EncodeToString(sp.Value), + ) +} + type RDataHTTPS struct { - Data string // TODO: add proper parsing + SvcPriority uint16 `json:"svc-priority"` + Length int `json:"length"` + TargetName string `json:"target-name"` + SvcParams []*SvcParam `json:"svc-params"` +} + +func (d *RDataHTTPS) printSvcParams() string { + var sb strings.Builder + for _, p := range d.SvcParams { + if p == nil { + continue + } + sb.WriteString(p.String()) + } + return strings.TrimRight(sb.String(), "\n") } func (d *RDataHTTPS) String() string { - return d.Data + return fmt.Sprintf(`SvcPriority: %d + - TargetName: %s + - SvcParams: +%s`, + d.SvcPriority, + d.TargetName, + d.printSvcParams(), + ) } type RDataUnknown struct { - Data string + Data string `json:"data"` } func (d *RDataUnknown) String() string { @@ -542,15 +693,24 @@ func (d *RDataUnknown) String() string { // extractDomain extracts the DNS domain name from the given payload and tail. // // The domain name is parsed according to RFC 1035 section 4.1. -func extractDomain(payload, tail []byte) (string, []byte) { +func extractDomain(payload, tail []byte) (string, []byte, error) { // see https://brunoscheufler.com/blog/2024-05-12-building-a-dns-message-parser#domain-names var domainName string - for { + for len(tail) > 0 { blen := tail[0] if blen>>6 == 0b11 { + if len(tail) < 2 { + return "", nil, ErrSliceBounds + } // compressed message offset is 14 bits according to RFC 1035 section 4.1.4 offset := binary.BigEndian.Uint16(tail[0:2])&(1<<14-1) - headerSizeDNS - part, _ := extractDomain(payload, payload[offset:]) // TODO: iterative approach + if offset > uint16(len(payload)) { + return "", nil, ErrSliceBounds + } + part, _, err := extractDomain(payload, payload[offset:]) // TODO: iterative approach + if err != nil { + return "", nil, err + } domainName += part tail = tail[2:] break @@ -559,17 +719,27 @@ func extractDomain(payload, tail []byte) (string, []byte) { if blen == 0 { break } + if int(blen) > len(tail) { + return "", nil, ErrSliceBounds + } domainName += bytesToStr(tail[0:blen]) domainName += "." tail = tail[blen:] } - return strings.TrimRight(domainName, "."), tail + return strings.TrimRight(domainName, "."), tail, nil } -func parseQuery(payload, tail []byte) (*QueryEntry, []byte) { +func parseQuery(payload, tail []byte) (*QueryEntry, []byte, error) { var domain string - domain, tail = extractDomain(payload, tail) + var err error + domain, tail, err = extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } + if len(tail) < 4 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) cls := binary.BigEndian.Uint16(tail[2:4]) tail = tail[4:] @@ -577,43 +747,69 @@ func parseQuery(payload, tail []byte) (*QueryEntry, []byte) { Name: domain, Type: newRecordType(typ), Class: newRecordClass(cls), - }, tail + }, tail, nil } -func parseQueries(payload, tail []byte, numRecords uint16) ([]*QueryEntry, []byte) { +func parseQueries(payload, tail []byte, numRecords uint16) ([]*QueryEntry, []byte, error) { queries := make([]*QueryEntry, numRecords) + var err error for i := range queries { - queries[i], tail = parseQuery(payload, tail) + queries[i], tail, err = parseQuery(payload, tail) + if err != nil { + return nil, nil, err + } } - return queries, tail + return queries, tail, nil } // https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-4 -func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte) { +func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte, error) { var rdata fmt.Stringer + if rdl > len(tail) { + return nil, nil, ErrSliceBounds + } switch typ { case 1: - addr, _ := netip.AddrFromSlice(tail[0:rdl]) + addr, ok := netip.AddrFromSlice(tail[0:rdl]) + if !ok { + return nil, nil, ErrParsingAddress + } rdata = &RDataA{Address: addr} case 2: - domain, _ := extractDomain(payload, tail) + domain, _, err := extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } rdata = &RDataNS{NsdName: domain} case 5: - domain, _ := extractDomain(payload, tail) + domain, _, err := extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } rdata = &RDataCNAME{CName: domain} case 6: var ( primary string mailbox string + err error ) ttail := tail - primary, ttail = extractDomain(payload, ttail) - mailbox, ttail = extractDomain(payload, ttail) + primary, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + mailbox, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + if len(ttail) < 20 { + return nil, nil, ErrSliceBounds + } serial := binary.BigEndian.Uint32(ttail[0:4]) refresh := binary.BigEndian.Uint32(ttail[4:8]) retry := binary.BigEndian.Uint32(ttail[8:12]) expire := binary.BigEndian.Uint32(ttail[12:16]) - min := binary.BigEndian.Uint32(ttail[16:20]) + minttl := binary.BigEndian.Uint32(ttail[16:20]) rdata = &RDataSOA{ PrimaryNS: primary, RespAuthorityMailbox: mailbox, @@ -621,11 +817,14 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte RefreshInterval: refresh, RetryInterval: retry, ExpireLimit: expire, - MinimumTTL: min, + MinimumTTL: minttl, } case 15: preference := binary.BigEndian.Uint16(tail[0:2]) - domain, _ := extractDomain(payload, tail[2:rdl]) + domain, _, err := extractDomain(payload, tail[2:rdl]) + if err != nil { + return nil, nil, err + } rdata = &RDataMX{ Preference: preference, Exchange: domain, @@ -633,9 +832,15 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte case 16: rdata = &RDataTXT{TxtData: string(tail[:rdl])} case 28: - addr, _ := netip.AddrFromSlice(tail[0:rdl]) + addr, ok := netip.AddrFromSlice(tail[0:rdl]) + if !ok { + return nil, nil, ErrParsingAddress + } rdata = &RDataAAAA{Address: addr} case 41: + if len(tail) < 8 { + return nil, nil, ErrSliceBounds + } ups := binary.BigEndian.Uint16(tail[0:2]) hb := tail[2] ednsv := tail[3] @@ -649,35 +854,84 @@ func parseRData(payload, tail []byte, typ uint16, rdl int) (fmt.Stringer, []byte DataLen: uint16(rdl), } case 65: - rdata = &RDataHTTPS{Data: string(tail[:rdl])} + if len(tail) < 3 { + return nil, nil, ErrSliceBounds + } + priority := binary.BigEndian.Uint16(tail[0:2]) + nameLength := tail[2] + var target string + var err error + ttail := tail[:rdl] + if nameLength == 0 { + target = "Root" + ttail = ttail[3:] + } else { + target, ttail, err = extractDomain(payload, ttail) + if err != nil { + return nil, nil, err + } + } + svcParams := make([]*SvcParam, 10) + var svcParam *SvcParam + for len(ttail) > 0 { + svcParam, ttail, err = newSvcParam(ttail) + if err != nil { + return nil, nil, err + } + if svcParam.Key.Desc == "Unknown" { + continue + } + svcParams[svcParam.Key.Val] = svcParam + } + rdata = &RDataHTTPS{SvcPriority: priority, Length: int(nameLength), TargetName: target, SvcParams: svcParams} default: rdata = &RDataUnknown{Data: string(tail[:rdl])} } - return rdata, tail[rdl:] + if rdl > len(tail) { + return nil, nil, ErrSliceBounds + } + return rdata, tail[rdl:], nil } -func parseRoot(payload, tail []byte) (*ResourceRecord, []byte) { +func parseRoot(payload, tail []byte) (*ResourceRecord, []byte, error) { + if len(tail) < 10 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) rdl := int(binary.BigEndian.Uint16(tail[8:10])) var rdata fmt.Stringer - rdata, tail = parseRData(payload, tail[2:], typ, rdl) + var err error + rdata, tail, err = parseRData(payload, tail[2:], typ, rdl) + if err != nil { + return nil, nil, err + } return &ResourceRecord{ Name: "Root", Type: newRecordType(typ), Class: &RecordClass{}, RData: rdata, - }, tail + }, tail, nil } -func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte) { +func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte, error) { var domain string - domain, tail = extractDomain(payload, tail) + var err error + domain, tail, err = extractDomain(payload, tail) + if err != nil { + return nil, nil, err + } + if len(tail) < 10 { + return nil, nil, ErrSliceBounds + } typ := binary.BigEndian.Uint16(tail[0:2]) cls := binary.BigEndian.Uint16(tail[2:4]) ttl := binary.BigEndian.Uint32(tail[4:8]) rdl := binary.BigEndian.Uint16(tail[8:10]) var rdata fmt.Stringer - rdata, tail = parseRData(payload, tail[10:], typ, int(rdl)) + rdata, tail, err = parseRData(payload, tail[10:], typ, int(rdl)) + if err != nil { + return nil, nil, err + } return &ResourceRecord{ Name: domain, Type: newRecordType(typ), @@ -685,17 +939,27 @@ func parseResourceRecord(payload, tail []byte) (*ResourceRecord, []byte) { TTL: ttl, RDLength: rdl, RData: rdata, - }, tail + }, tail, nil } -func parseResourceRecords(payload, tail []byte, numRecords uint16) ([]*ResourceRecord, []byte) { +func parseResourceRecords(payload, tail []byte, numRecords uint16) ([]*ResourceRecord, []byte, error) { + if len(tail) < 1 { + return nil, nil, ErrSliceBounds + } records := make([]*ResourceRecord, numRecords) + var err error for i := range records { if tail[0] != 0 { - records[i], tail = parseResourceRecord(payload, tail) + records[i], tail, err = parseResourceRecord(payload, tail) + if err != nil { + return nil, nil, err + } } else { - records[i], tail = parseRoot(payload, tail[1:]) + records[i], tail, err = parseRoot(payload, tail[1:]) + if err != nil { + return nil, nil, err + } } } - return records, tail + return records, tail, nil } diff --git a/layers/ethernet.go b/layers/ethernet.go index da7b412..fe9b147 100644 --- a/layers/ethernet.go +++ b/layers/ethernet.go @@ -92,12 +92,17 @@ func (ef *EthernetFrame) UnmarshalBinary(data []byte) error { if len(data) < headerSizeEthernet { return fmt.Errorf("did not read a complete Ethernet frame, only %d bytes read", len(data)) } - ef.DstMAC = net.HardwareAddr(data[0:6]) - ef.SrcMAC = net.HardwareAddr(data[6:12]) - et := EtherType(binary.BigEndian.Uint16(data[12:14])) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + ef.DstMAC = net.HardwareAddr(buf[0:6]) + ef.SrcMAC = net.HardwareAddr(buf[6:12]) + et := EtherType(binary.BigEndian.Uint16(buf[12:14])) etdesc := ethertypedesc(et) + if etdesc == "Unknown" { + return fmt.Errorf("failed determining Ethernet type") + } ef.EtherType = &EthernetType{Val: et, Desc: etdesc} - ef.Payload = data[headerSizeEthernet:] + ef.Payload = buf[headerSizeEthernet:] ef.DstVendor = oui.VendorWithMAC(ef.DstMAC) ef.SrcVendor = oui.VendorWithMAC(ef.SrcMAC) return nil @@ -108,11 +113,17 @@ func (ef *EthernetFrame) Parse(data []byte) error { return ef.UnmarshalBinary(data) } -// NextLayer returns the name and payload of the next layer protocol based on the EtherType field of the EthernetFrame. -func (ef *EthernetFrame) NextLayer() (string, []byte) { - return ef.EtherType.Desc, ef.Payload +func (ef *EthernetFrame) NextLayer() Layer { + if next := GetLayer(LayerName(ef.EtherType.Desc)); next != nil { + if err := next.Parse(ef.Payload); err == nil { + return next + } + } + return ParseNextLayer(ef.Payload, nil, nil) } +func (ef *EthernetFrame) Name() LayerName { return LayerETH } + func ethertypedesc(et EtherType) string { var etdesc string switch et { @@ -123,7 +134,7 @@ func ethertypedesc(et EtherType) string { case EtherTypeIPv6: etdesc = "IPv6" default: - etdesc = "" + etdesc = "Unknown" } return etdesc } diff --git a/layers/ftp.go b/layers/ftp.go index f16b288..df7cd90 100644 --- a/layers/ftp.go +++ b/layers/ftp.go @@ -21,22 +21,36 @@ func (f *FTPMessage) Summary() string { } func (f *FTPMessage) Parse(data []byte) error { - f.summary = nil - f.data = nil - sp := bytes.Split(data, crlf) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + if !checkFTP(buf) { + return fmt.Errorf("malformed ftp message") + } + sp := bytes.Split(buf, crlf) lsp := len(sp) switch { case lsp > 2: - f.summary = bytes.Join(sp[:2], bspace) + f.summary = bytes.TrimSpace(bytes.Join(sp[:2], bspace)) sp[0] = joinBytes(dash, sp[0]) - f.data = bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf) + f.data = bytes.TrimSpace(bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf)) case lsp > 1: - f.summary = sp[0] + f.summary = bytes.TrimSpace(sp[0]) sp[0] = joinBytes(dash, sp[0]) - f.data = bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf) + f.data = bytes.TrimSpace(bytes.TrimSuffix(bytes.TrimSuffix(bytes.Join(sp, lfd), dash), lf)) default: + return fmt.Errorf("failed parsing FTP message") + } + if len(f.summary) == 0 || len(f.data) == 0 { + return fmt.Errorf("failed parsing FTP message") } return nil } -func (f *FTPMessage) NextLayer() (layer string, payload []byte) { return } +func (f *FTPMessage) NextLayer() Layer { return nil } +func (f *FTPMessage) Name() LayerName { return LayerFTP } + +func checkFTP(data []byte) bool { + return (len(data) >= 4 && isDigit(data[0]) && isDigit(data[1]) && + isDigit(data[2]) && (data[3] == ' ' || data[3] == '-')) || + (len(data) >= 3 && isUpper(data[0]) && isUpper(data[1]) && isUpper(data[2])) +} diff --git a/layers/http.go b/layers/http.go index 732c3da..1141abd 100644 --- a/layers/http.go +++ b/layers/http.go @@ -59,13 +59,15 @@ func (h *HTTPMessage) Summary() string { } func (h *HTTPMessage) Parse(data []byte) error { - if !bytes.Contains(data, protohttp10) && !bytes.Contains(data, protohttp11) { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + if !bytes.Contains(buf, protohttp10) && !bytes.Contains(buf, protohttp11) { h.Request = nil h.Response = nil - return nil + return fmt.Errorf("message does not contain protocol") } - reader := bufio.NewReader(bytes.NewReader(data)) - if bytes.HasPrefix(data, protohttp11) || bytes.HasPrefix(data, protohttp10) { + reader := bufio.NewReader(bytes.NewReader(buf)) + if bytes.HasPrefix(buf, protohttp11) || bytes.HasPrefix(buf, protohttp10) { resp, err := http.ReadResponse(reader, nil) if err != nil { return err @@ -73,7 +75,7 @@ func (h *HTTPMessage) Parse(data []byte) error { h.Response = resp h.Request = nil } else { - reader := bufio.NewReader(bytes.NewReader(data)) + reader := bufio.NewReader(bytes.NewReader(buf)) req, err := http.ReadRequest(reader) if err != nil { return err @@ -84,10 +86,11 @@ func (h *HTTPMessage) Parse(data []byte) error { return nil } -func (h *HTTPMessage) NextLayer() (layer string, payload []byte) { return } +func (h *HTTPMessage) NextLayer() Layer { return nil } +func (h *HTTPMessage) Name() LayerName { return LayerHTTP } type HTTPRequestWrapper struct { - Request HTTPRequest `json:"http_request"` + Request *HTTPRequest `json:"http_request"` } type HTTPRequest struct { @@ -100,7 +103,7 @@ type HTTPRequest struct { } type HTTPResponseWrapper struct { - Response HTTPResponse `json:"http_response"` + Response *HTTPResponse `json:"http_response"` } type HTTPResponse struct { @@ -112,7 +115,7 @@ type HTTPResponse struct { func (h *HTTPMessage) MarshalJSON() ([]byte, error) { if h.Request != nil { - return json.Marshal(&HTTPRequestWrapper{Request: HTTPRequest{ + return json.Marshal(&HTTPRequestWrapper{Request: &HTTPRequest{ Host: h.Request.Host, URI: h.Request.RequestURI, Method: h.Request.Method, @@ -121,7 +124,7 @@ func (h *HTTPMessage) MarshalJSON() ([]byte, error) { Header: h.Request.Header, }}) } else if h.Response != nil { - return json.Marshal(&HTTPResponseWrapper{Response: HTTPResponse{ + return json.Marshal(&HTTPResponseWrapper{Response: &HTTPResponse{ Proto: h.Response.Proto, Status: h.Response.Status, ContentLength: int(h.Response.ContentLength), diff --git a/layers/icmp.go b/layers/icmp.go index c0351c8..83c0001 100644 --- a/layers/icmp.go +++ b/layers/icmp.go @@ -46,10 +46,12 @@ func (i *ICMPSegment) Parse(data []byte) error { if len(data) < headerSizeICMP { return fmt.Errorf("minimum header size for ICMP is %d bytes, got %d bytes", headerSizeICMP, len(data)) } - i.Type = data[0] - i.Code = data[1] - i.Checksum = binary.BigEndian.Uint16(data[2:4]) - i.Data = data[headerSizeICMP:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + i.Type = buf[0] + i.Code = buf[1] + i.Checksum = binary.BigEndian.Uint16(buf[2:4]) + i.Data = buf[headerSizeICMP:] var pLen int switch i.Type { case 0, 3, 5, 8, 11: @@ -63,9 +65,14 @@ func (i *ICMPSegment) Parse(data []byte) error { return fmt.Errorf("minimum payload length for ICMP with type %d is %d bytes", i.Type, pLen) } i.TypeDesc, i.CodeDesc = i.typecode() + if i.TypeDesc == "Unknown" || i.CodeDesc == "Unknown" { + return fmt.Errorf("failed determining type or code") + } return nil } -func (i *ICMPSegment) NextLayer() (layer string, payload []byte) { return } + +func (i *ICMPSegment) NextLayer() Layer { return nil } +func (i *ICMPSegment) Name() LayerName { return LayerICMP } func (i *ICMPSegment) typecode() (string, string) { // https://en.wikipedia.org/wiki/Internet_Control_Message_Protocol diff --git a/layers/icmpv6.go b/layers/icmpv6.go index fde107d..51ab311 100644 --- a/layers/icmpv6.go +++ b/layers/icmpv6.go @@ -44,10 +44,12 @@ func (i *ICMPv6Segment) Parse(data []byte) error { if len(data) < headerSizeICMPv6 { return fmt.Errorf("minimum header size for ICMPv6 is %d bytes, got %d bytes", headerSizeICMPv6, len(data)) } - i.Type = data[0] - i.Code = data[1] - i.Checksum = binary.BigEndian.Uint16(data[2:headerSizeICMPv6]) - i.Data = data[headerSizeICMPv6:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + i.Type = buf[0] + i.Code = buf[1] + i.Checksum = binary.BigEndian.Uint16(buf[2:headerSizeICMPv6]) + i.Data = buf[headerSizeICMPv6:] var pLen int switch i.Type { case 1, 2, 3, 4, 128, 129, 133: @@ -68,7 +70,8 @@ func (i *ICMPv6Segment) Parse(data []byte) error { return nil } -func (i *ICMPv6Segment) NextLayer() (layer string, payload []byte) { return } +func (i *ICMPv6Segment) NextLayer() Layer { return nil } +func (i *ICMPv6Segment) Name() LayerName { return LayerICMPv6 } func (i *ICMPv6Segment) typecode() (string, string) { // https://en.wikipedia.org/wiki/ICMPv6 diff --git a/layers/ipv4.go b/layers/ipv4.go index 5aba8c4..a64d387 100644 --- a/layers/ipv4.go +++ b/layers/ipv4.go @@ -73,17 +73,22 @@ func NewIPv4Packet(srcIP, dstIP netip.Addr, proto IPProto, payload []byte) (*IPv return nil, fmt.Errorf("malformed IPv4 address") } ipPacket := &IPv4Packet{ - Version: 4, - IHL: 5, - TotalLength: uint16(headerSizeIPv4 + len(payload)), - Flags: NewIPv4Flags(0), - TTL: 64, - Protocol: &IPv4Proto{Val: proto, Desc: protodesc(proto)}, - SrcIP: srcIP, - DstIP: dstIP, - Payload: payload, + Version: 4, + IHL: 5, + TotalLength: uint16(headerSizeIPv4 + len(payload)), + Identification: MustGenerateRandomUint16NE(), + Flags: NewIPv4Flags(2), + TTL: 128, + Protocol: &IPv4Proto{Val: proto, Desc: protodesc(proto)}, + SrcIP: srcIP, + DstIP: dstIP, + Payload: payload, } - ipPacket.HeaderChecksum, _ = CalculateIPv4Checksum(ipPacket.ToBytes()) + headerChecksum, err := CalculateIPv4Checksum(ipPacket.ToBytes()) + if err != nil { + return nil, fmt.Errorf("failed calculating checksum") + } + ipPacket.HeaderChecksum = headerChecksum return ipPacket, nil } @@ -155,38 +160,56 @@ func (p *IPv4Packet) UnmarshalBinary(data []byte) error { if len(data) < headerSizeIPv4 { return fmt.Errorf("minimum header size for IPv4 is %d bytes, got %d bytes", headerSizeIPv4, len(data)) } - versionIHL := data[0] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + versionIHL := buf[0] p.Version = versionIHL >> 4 + if p.Version != 4 { + return fmt.Errorf("unknown version") + } p.IHL = versionIHL & 15 - dscpECN := data[1] + dscpECN := buf[1] p.DSCP = dscpECN >> 2 p.DSCPDesc = dscpdesc(p.DSCP) + if p.DSCPDesc == "Unknown" { + return fmt.Errorf("unknown DSCP") + } p.ECN = dscpECN & 3 - p.TotalLength = binary.BigEndian.Uint16(data[2:4]) - p.Identification = binary.BigEndian.Uint16(data[4:6]) - flagsOffset := binary.BigEndian.Uint16(data[6:8]) + p.TotalLength = binary.BigEndian.Uint16(buf[2:4]) + if int(p.TotalLength) != len(buf) { + return fmt.Errorf("total length is not equal to actual packet size") + } + p.Identification = binary.BigEndian.Uint16(buf[4:6]) + flagsOffset := binary.BigEndian.Uint16(buf[6:8]) flags := uint8(flagsOffset >> 13) p.Flags = NewIPv4Flags(flags) p.FragmentOffset = flagsOffset & (1<<13 - 1) - p.TTL = data[8] - proto := IPProto(data[9]) - p.Protocol = &IPv4Proto{Val: proto, Desc: protodesc(proto)} - p.HeaderChecksum = binary.BigEndian.Uint16(data[headerChecksumOffsetIPv4:12]) + p.TTL = buf[8] + proto := IPProto(buf[9]) + protodesc := protodesc(proto) + if protodesc == "Unknown" { + return fmt.Errorf("unknown protocol") + } + p.Protocol = &IPv4Proto{Val: proto, Desc: protodesc} + p.HeaderChecksum = binary.BigEndian.Uint16(buf[headerChecksumOffsetIPv4:12]) var ok bool - p.SrcIP, ok = netip.AddrFromSlice(data[12:16]) + p.SrcIP, ok = netip.AddrFromSlice(buf[12:16]) if !ok { return fmt.Errorf("malformed IPv4 address") } - p.DstIP, ok = netip.AddrFromSlice(data[16:headerSizeIPv4]) + p.DstIP, ok = netip.AddrFromSlice(buf[16:headerSizeIPv4]) if !ok { return fmt.Errorf("malformed IPv4 address") } if p.IHL > 5 { offset := headerSizeIPv4 + ((p.IHL - 5) << 2) - p.Options = data[headerSizeIPv4:offset] - p.Payload = data[offset:] + if int(offset) > len(buf) { + return ErrSliceBounds + } + p.Options = buf[headerSizeIPv4:offset] + p.Payload = buf[offset:] } else { - p.Payload = data[headerSizeIPv4:] + p.Payload = buf[headerSizeIPv4:] } return nil } @@ -212,14 +235,17 @@ func protodesc(proto IPProto) string { return protodesc } -func (p *IPv4Packet) NextLayer() (string, []byte) { - layer := p.Protocol.Desc - if layer == "Unknown" { - layer = "" +func (p *IPv4Packet) NextLayer() Layer { + if next := GetLayer(LayerName(p.Protocol.Desc)); next != nil { + if err := next.Parse(p.Payload); err == nil { + return next + } } - return layer, p.Payload + return ParseNextLayer(p.Payload, nil, nil) } +func (p *IPv4Packet) Name() LayerName { return LayerIPv4 } + func dscpdesc(dscp uint8) string { // https://en.wikipedia.org/wiki/Differentiated_services var dscpdesc string @@ -271,6 +297,10 @@ func (p *IPv4Packet) PseudoHeader() *IPv4PseudoHeader { return &IPv4PseudoHeader{SrcIP: p.SrcIP, DstIP: p.DstIP, Protocol: p.Protocol, TotalLength: uint16(len(p.Payload))} } +func (p *IPv4Packet) SetPayload(payload []byte) { + p.Payload = payload +} + type IPv4PseudoHeader struct { SrcIP netip.Addr DstIP netip.Addr diff --git a/layers/ipv6.go b/layers/ipv6.go index 8efaef2..4974693 100644 --- a/layers/ipv6.go +++ b/layers/ipv6.go @@ -15,7 +15,7 @@ type TrafficClass struct { ECN uint8 } -func newTrafficiClass(tc uint8) *TrafficClass { +func newTrafficClass(tc uint8) *TrafficClass { dscpbin := tc >> 2 return &TrafficClass{ Raw: tc, @@ -44,7 +44,7 @@ type IPv6Packet struct { HopLimit uint8 SrcIP netip.Addr // The unicast IPv6 address of the sending node. DstIP netip.Addr // The IPv6 unicast or multicast address of the destination node(s). - payload []byte + Payload []byte } func (p *IPv6Packet) String() string { @@ -67,7 +67,7 @@ func (p *IPv6Packet) String() string { p.HopLimit, p.SrcIP, p.DstIP, - len(p.payload), + len(p.Payload), ) } @@ -80,21 +80,42 @@ func (p *IPv6Packet) Parse(data []byte) error { if len(data) < headerSizeIPv6 { return fmt.Errorf("minimum header size for IPv6 is %d bytes, got %d bytes", headerSizeIPv6, len(data)) } - versionTrafficFlow := binary.BigEndian.Uint32(data[0:4]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + versionTrafficFlow := binary.BigEndian.Uint32(buf[0:4]) p.Version = uint8(versionTrafficFlow >> 28) - p.TrafficClass = newTrafficiClass(uint8((versionTrafficFlow >> 20) & 0xFF)) + if p.Version != 6 { + return fmt.Errorf("unknown version") + } + p.TrafficClass = newTrafficClass(uint8((versionTrafficFlow >> 20) & 0xFF)) + if p.TrafficClass.DSCPDesc == "Unknown" { + return fmt.Errorf("unknown DSCP") + } p.FlowLabel = versionTrafficFlow & (1<<20 - 1) - p.PayloadLength = binary.BigEndian.Uint16(data[4:6]) - p.NextHeader = data[6] + p.PayloadLength = binary.BigEndian.Uint16(buf[4:6]) + p.NextHeader = buf[6] p.NextHeaderDesc = p.nextHeader() - p.HopLimit = data[7] - p.SrcIP, _ = netip.AddrFromSlice(data[8:24]) - p.DstIP, _ = netip.AddrFromSlice(data[24:headerSizeIPv6]) - p.payload = data[headerSizeIPv6:] + if p.NextHeaderDesc == "Unknown" { + return fmt.Errorf("unknown next header") + } + p.HopLimit = buf[7] + var ok bool + p.SrcIP, ok = netip.AddrFromSlice(buf[8:24]) + if !ok { + return fmt.Errorf("malformed IPv6 address") + } + p.DstIP, ok = netip.AddrFromSlice(buf[24:headerSizeIPv6]) + if !ok { + return fmt.Errorf("malformed IPv6 address") + } + p.Payload = buf[headerSizeIPv6:] + if p.PayloadLength != 0 && int(p.PayloadLength) != len(p.Payload) { + return fmt.Errorf("payload length filed is not equal to actual payload size") + } return nil } -func (p *IPv6Packet) NextLayer() (string, []byte) { +func (p *IPv6Packet) nextLayer() string { // https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers var layer string switch p.NextHeader { @@ -107,9 +128,20 @@ func (p *IPv6Packet) NextLayer() (string, []byte) { default: layer = "" } - return layer, p.payload + return layer } +func (p *IPv6Packet) NextLayer() Layer { + if next := GetLayer(LayerName(p.nextLayer())); next != nil { + if err := next.Parse(p.Payload); err == nil { + return next + } + } + return ParseNextLayer(p.Payload, nil, nil) +} + +func (p *IPv6Packet) Name() LayerName { return LayerIPv6 } + func (p *IPv6Packet) nextHeader() string { // https://en.wikipedia.org/wiki/List_of_IP_protocol_numbers var header string @@ -141,7 +173,7 @@ func (p *IPv6Packet) nextHeader() string { case 140: header = "Shim6 Protocol" default: - header = "" + header = "Unknown" } return header } diff --git a/layers/ipv6_test.go b/layers/ipv6_test.go index 5cda13d..5f917bb 100644 --- a/layers/ipv6_test.go +++ b/layers/ipv6_test.go @@ -31,11 +31,13 @@ func TestParseIPv6(t *testing.T) { HopLimit: 64, SrcIP: netip.AddrFrom16([16]byte{ 0xFD, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xF2, 0x3F, 0xD1, 0x59, 0x50, 0x48, 0x9C, 0x14}), + 0xF2, 0x3F, 0xD1, 0x59, 0x50, 0x48, 0x9C, 0x14, + }), DstIP: netip.AddrFrom16([16]byte{ 0x26, 0x20, 0x00, 0x2D, 0x40, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2B}), - payload: []byte{}, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x2B, + }), + Payload: []byte{}, } ip := &IPv6Packet{} packet, close := testPacket(t, "ipv6") diff --git a/layers/layers.go b/layers/layers.go index 46c526a..b50948a 100644 --- a/layers/layers.go +++ b/layers/layers.go @@ -2,46 +2,249 @@ package layers import ( + "bytes" + "crypto/rand" + "encoding/binary" "fmt" "unsafe" + + "github.com/shadowy-pycoder/mshark/native" +) + +const maxLenSummary = 110 + +type LayerName string + +const ( + LayerETH LayerName = "ETH" + LayerIPv4 LayerName = "IPv4" + LayerIPv6 LayerName = "IPv6" + LayerARP LayerName = "ARP" + LayerTCP LayerName = "TCP" + LayerUDP LayerName = "UDP" + LayerICMP LayerName = "ICMP" + LayerICMPv6 LayerName = "ICMPv6" + LayerDNS LayerName = "DNS" + LayerFTP LayerName = "FTP" + LayerHTTP LayerName = "HTTP" + LayerSNMP LayerName = "SNMP" + LayerSSH LayerName = "SSH" + LayerTLS LayerName = "TLS" ) -const maxLenSummary = 100 - -var LayerMap = map[string]Layer{ - "ETH": &EthernetFrame{}, - "IPv4": &IPv4Packet{}, - "IPv6": &IPv6Packet{}, - "ARP": &ARPPacket{}, - "TCP": &TCPSegment{}, - "UDP": &UDPSegment{}, - "ICMP": &ICMPSegment{}, - "ICMPv6": &ICMPv6Segment{}, - "DNS": &DNSMessage{}, - "FTP": &FTPMessage{}, - "HTTP": &HTTPMessage{}, - "SNMP": &SNMPMessage{}, - "SSH": &SSHMessage{}, - "TLS": &TLSMessage{}, +var Layers = []LayerName{ + LayerETH, + LayerIPv4, + LayerIPv6, + LayerTLS, + LayerHTTP, + LayerDNS, + LayerARP, + LayerTCP, + LayerUDP, + LayerICMP, + LayerICMPv6, + LayerSNMP, + LayerSSH, + LayerFTP, } var ( - bspace = []byte(" ") - dash = []byte("- ") - lfd = []byte("\n- ") - slfd = "\n- " - lf = []byte("\n") - crlf = []byte("\r\n") - dcrlf = []byte("\r\n\r\n") - ellipsis = []byte("...") - contdata = []byte("Continuation data") + bspace = []byte(" ") + dash = []byte("- ") + lfd = []byte("\n- ") + slfd = "\n- " + lf = []byte("\n") + crlf = []byte("\r\n") + dcrlf = []byte("\r\n\r\n") + ellipsis = []byte("...") + contdata = []byte("Continuation data") + ErrParsingAddress = fmt.Errorf("failed parsing IP address") + ErrSliceBounds = fmt.Errorf("slice bounds out of range") ) type Layer interface { fmt.Stringer Parse(data []byte) error - NextLayer() (layer string, payload []byte) + NextLayer() Layer Summary() string + Name() LayerName +} + +func parseNextLayerFallback(data []byte) Layer { + if len(data) == 0 { + return nil + } + for _, layer := range Layers { + next := GetLayer(layer) + if err := next.Parse(data); err == nil { + return next + } + } + return nil +} + +func parseNextLayerFromBytes(data []byte) Layer { + if len(data) == 0 { + return nil + } + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + var next Layer + firstByte := buf[0] + if firstByte >= 0x45 && firstByte <= 0x4F { + next = GetLayer(LayerIPv4) + if err := next.Parse(buf); err == nil { + return next + } + } + if firstByte>>4 == 6 { + next = GetLayer(LayerIPv6) + if err := next.Parse(buf); err == nil { + return next + } + } + if firstByte == HandshakeTLSVal { + next = GetLayer(LayerTLS) + if err := next.Parse(buf); err == nil { + return next + } + } + if checkFTP(buf) { + next = GetLayer(LayerFTP) + if err := next.Parse(buf); err == nil { + return next + } + } + if len(buf) > 3 { + b1 := binary.BigEndian.Uint16(buf[0:2]) + b2 := binary.BigEndian.Uint16(buf[2:4]) + if b1 == 1 && (b2 == 0x0800 || b2 == 0x86dd) { + next = GetLayer(LayerARP) + if err := next.Parse(buf); err == nil { + return next + } + } + } + if checkSNMP(buf) { + next = GetLayer(LayerSNMP) + if err := next.Parse(buf); err == nil { + return next + } + } + if len(buf) > 15 { + b1 := binary.BigEndian.Uint16(buf[12:14]) + if b1 == 0x0806 || b1 == 0x0800 || b1 == 0x86dd { + next = GetLayer(LayerETH) + if err := next.Parse(buf); err == nil { + return next + } + } + } + if bytes.Contains(buf, protohttp10) || bytes.Contains(buf, protohttp11) { + next = GetLayer(LayerHTTP) + if err := next.Parse(buf); err == nil { + return next + } + } + if bytes.Contains(buf, protoSSH) { + next = GetLayer(LayerSSH) + if err := next.Parse(buf); err == nil { + return next + } + } + return nil +} + +func addrMatch(src, dst *uint16, ports []uint16) bool { + var srcPort, dstPort uint16 + if src != nil { + srcPort = *src + } + if dst != nil { + dstPort = *dst + } + for _, port := range ports { + if srcPort == port || dstPort == port { + return true + } + } + return false +} + +func parseNextLayerFromPorts(data []byte, src, dst *uint16) Layer { + if len(data) == 0 { + return nil + } + var next Layer + switch { + case addrMatch(src, dst, []uint16{53, 5353, 853, 5355}): + next = GetLayer(LayerDNS) + case addrMatch(src, dst, []uint16{80, 8080, 8000, 8888, 81, 591, 5911}): + next = GetLayer(LayerHTTP) + case addrMatch(src, dst, []uint16{161, 162, 10161, 10162, 1161, 2161}): + next = GetLayer(LayerSNMP) + case addrMatch(src, dst, []uint16{21, 20, 2121, 8021}): + next = GetLayer(LayerFTP) + case addrMatch(src, dst, []uint16{22, 2222, 2200, 222, 2022}): + next = GetLayer(LayerSSH) + case addrMatch(src, dst, []uint16{443, 465, 993, 995, 8443, 9443, 10443, 8444, 5228}): + next = GetLayer(LayerTLS) + default: + return nil + } + if err := next.Parse(data); err == nil { + return next + } else { + return nil + } +} + +func ParseNextLayer(data []byte, src, dst *uint16) Layer { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + var next Layer + if src != nil || dst != nil { + if next = parseNextLayerFromPorts(buf, src, dst); next != nil { + return next + } + } + return parseNextLayerFromBytes(buf) +} + +func GetLayer(layer LayerName) Layer { + switch layer { + case LayerETH: + return &EthernetFrame{} + case LayerIPv4: + return &IPv4Packet{} + case LayerIPv6: + return &IPv6Packet{} + case LayerARP: + return &ARPPacket{} + case LayerTCP: + return &TCPSegment{} + case LayerUDP: + return &UDPSegment{} + case LayerICMP: + return &ICMPSegment{} + case LayerICMPv6: + return &ICMPv6Segment{} + case LayerDNS: + return &DNSMessage{} + case LayerFTP: + return &FTPMessage{} + case LayerHTTP: + return &HTTPMessage{} + case LayerSNMP: + return &SNMPMessage{} + case LayerSSH: + return &SSHMessage{} + case LayerTLS: + return &TLSMessage{} + default: + return nil + } } func bytesToStr(b []byte) string { @@ -65,3 +268,53 @@ func add16WithCarryWrapAround(x, y uint16) uint16 { sum32 = (sum32 & 0xFFFF) + (sum32 >> 16) return uint16(sum32) } + +func isDigit(b byte) bool { + return b >= '0' && b <= '9' +} + +func isUpper(b byte) bool { + return b >= 'A' && b <= 'Z' +} + +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} + +func GenerateRandomUint16LE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return binary.LittleEndian.Uint16(b), nil +} + +func GenerateRandomUint16BE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint16(b), nil +} + +func GenerateRandomUint16NE() (uint16, error) { + b, err := GenerateRandomBytes(2) + if err != nil { + return 0, err + } + return native.Endian.Uint16(b), nil +} + +func MustGenerateRandomUint16NE() uint16 { + rn, err := GenerateRandomUint16NE() + if err != nil { + panic(err) + } + return rn +} diff --git a/layers/snmp.go b/layers/snmp.go index 984de12..b232d7e 100644 --- a/layers/snmp.go +++ b/layers/snmp.go @@ -18,8 +18,18 @@ func (s *SNMPMessage) Summary() string { } func (s *SNMPMessage) Parse(data []byte) error { - s.Payload = data + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + if !checkSNMP(buf) { + return fmt.Errorf("not ASN.1 SEQUENCE") + } + s.Payload = buf return nil } -func (s *SNMPMessage) NextLayer() (layer string, payload []byte) { return } +func (s *SNMPMessage) NextLayer() Layer { return nil } +func (s *SNMPMessage) Name() LayerName { return LayerSNMP } + +func checkSNMP(data []byte) bool { + return len(data) > 6 && data[0] == 0x30 && (data[2] == 0x02 || data[2] == 0x04) +} diff --git a/layers/ssh.go b/layers/ssh.go index e128906..ea7e98d 100644 --- a/layers/ssh.go +++ b/layers/ssh.go @@ -9,6 +9,8 @@ import ( const messageSizeSSH = 6 +var protoSSH = []byte("SSH-") + type Message struct { PacketLength uint32 PaddingLength uint8 @@ -85,41 +87,52 @@ func (s *SSHMessage) Parse(data []byte) error { if len(data) < messageSizeSSH { return fmt.Errorf("minimum message size for SSH is %d bytes, got %d bytes", messageSizeSSH, len(data)) } - s.Protocol = "" - s.Messages = nil - if bytes.HasSuffix(data, crlf) { - s.Protocol = bytesToStr(bytes.TrimSuffix(data, crlf)) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + if bytes.HasSuffix(buf, crlf) { + p := bytes.TrimSuffix(buf, crlf) + if !bytes.Contains(p, protoSSH) { + return fmt.Errorf("message should contain SSH-") + } + s.Protocol = bytesToStr(p) return nil } s.Messages = make([]*Message, 0, 3) - for len(data) > 0 { + for len(buf) > 0 { + if len(buf) < 4 { + return ErrSliceBounds + } m := &Message{} s.Messages = append(s.Messages, m) - plen := binary.BigEndian.Uint32(data[0:4]) + plen := binary.BigEndian.Uint32(buf[0:4]) if plen > 0xffff { - m.Payload = data + m.Payload = buf break } - m.MesssageType = data[5] + if len(buf) < 5 { + return ErrSliceBounds + } + m.MesssageType = buf[5] if m.MesssageTypeDesc = mtypedesc(m.MesssageType); m.MesssageTypeDesc == "Unknown" { - m.Payload = data + m.Payload = buf break } m.PacketLength = plen - m.PaddingLength = data[4] + m.PaddingLength = buf[4] offset := int(messageSizeSSH + m.PacketLength - 2) - if offset <= len(data) { - m.Payload = data[messageSizeSSH:offset] - data = data[offset:] + if offset <= len(buf) { + m.Payload = buf[messageSizeSSH:offset] + buf = buf[offset:] } else { - m.Payload = data[messageSizeSSH:] + m.Payload = buf[messageSizeSSH:] break } } return nil } -func (s *SSHMessage) NextLayer() (layer string, payload []byte) { return } +func (s *SSHMessage) NextLayer() Layer { return nil } +func (s *SSHMessage) Name() LayerName { return LayerSSH } // https://www.iana.org/assignments/ssh-parameters/ssh-parameters.xhtml func mtypedesc(mtype uint8) string { diff --git a/layers/tcp.go b/layers/tcp.go index 56658cc..0df3a58 100644 --- a/layers/tcp.go +++ b/layers/tcp.go @@ -62,7 +62,7 @@ type TCPSegment struct { // indicating the last urgent data byte. UrgentPointer uint16 Options []byte // The length of this field is determined by the data offset field. - payload []byte + Payload []byte } func (t *TCPSegment) String() string { @@ -94,7 +94,7 @@ func (t *TCPSegment) String() string { t.UrgentPointer, len(t.Options), t.Options, - len(t.payload), + len(t.Payload), ) } @@ -105,7 +105,7 @@ func (t *TCPSegment) Summary() string { t.DstPort, t.Flags, t.WindowSize, - len(t.payload), + len(t.Payload), ) } @@ -114,43 +114,30 @@ func (t *TCPSegment) Parse(data []byte) error { if len(data) < headerSizeTCP { return fmt.Errorf("minimum header size for TCP is %d bytes, got %d bytes", headerSizeTCP, len(data)) } - t.SrcPort = binary.BigEndian.Uint16(data[0:2]) - t.DstPort = binary.BigEndian.Uint16(data[2:4]) - t.SeqNumber = binary.BigEndian.Uint32(data[4:8]) - t.AckNumber = binary.BigEndian.Uint32(data[8:12]) - offsetReservedFlags := binary.BigEndian.Uint16(data[12:14]) + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + t.SrcPort = binary.BigEndian.Uint16(buf[0:2]) + t.DstPort = binary.BigEndian.Uint16(buf[2:4]) + t.SeqNumber = binary.BigEndian.Uint32(buf[4:8]) + t.AckNumber = binary.BigEndian.Uint32(buf[8:12]) + offsetReservedFlags := binary.BigEndian.Uint16(buf[12:14]) t.DataOffset = uint8(offsetReservedFlags >> 12) t.Reserved = uint8((offsetReservedFlags >> 8) & 15) t.Flags = newTCPFlags(uint8(offsetReservedFlags & (1<<8 - 1))) - t.WindowSize = binary.BigEndian.Uint16(data[14:16]) - t.Checksum = binary.BigEndian.Uint16(data[16:18]) - t.UrgentPointer = binary.BigEndian.Uint16(data[18:headerSizeTCP]) - t.Options = data[headerSizeTCP : t.DataOffset<<2] - t.payload = data[t.DataOffset<<2:] + t.WindowSize = binary.BigEndian.Uint16(buf[14:16]) + t.Checksum = binary.BigEndian.Uint16(buf[16:18]) + t.UrgentPointer = binary.BigEndian.Uint16(buf[18:headerSizeTCP]) + offset := t.DataOffset << 2 + if len(buf) < int(offset) || int(offset) < headerSizeTCP { + return ErrSliceBounds + } + t.Options = buf[headerSizeTCP:offset] + t.Payload = buf[offset:] return nil } -func (t *TCPSegment) NextLayer() (string, []byte) { - return nextAppLayer(t.SrcPort, t.DstPort), t.payload +func (t *TCPSegment) NextLayer() Layer { + return ParseNextLayer(t.Payload, &t.SrcPort, &t.DstPort) } -func nextAppLayer(src, dst uint16) string { - var layer string - switch { - case src == 20 || dst == 20 || src == 21 || dst == 21: - layer = "FTP" - case src == 22 || dst == 22: - layer = "SSH" - case src == 53 || dst == 53: - layer = "DNS" - case src == 80 || dst == 80: - layer = "HTTP" - case src == 161 || dst == 161 || src == 162 || dst == 162: - layer = "SNMP" - case src == 443 || dst == 443: - layer = "TLS" - default: - layer = "" - } - return layer -} +func (t *TCPSegment) Name() LayerName { return LayerTCP } diff --git a/layers/tcp_test.go b/layers/tcp_test.go index b5627ce..3c924e1 100644 --- a/layers/tcp_test.go +++ b/layers/tcp_test.go @@ -46,7 +46,7 @@ func TestParseTCP(t *testing.T) { 0x02, 0x04, 0x05, 0xb4, 0x04, 0x02, 0x08, 0x0a, 0xac, 0xf8, 0x48, 0x3e, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07, }, - payload: []byte{}, + Payload: []byte{}, } tcp := &TCPSegment{} packet, close := testPacket(t, "tcp") diff --git a/layers/tls.go b/layers/tls.go index fff4e1f..cf110c0 100644 --- a/layers/tls.go +++ b/layers/tls.go @@ -86,11 +86,17 @@ type ServerName struct { } func (sn *ServerName) Parse(data []byte) error { + if len(data) < 9 { + return ErrSliceBounds + } sn.Type = binary.BigEndian.Uint16(data[0:2]) sn.Length = binary.BigEndian.Uint16(data[2:4]) sn.SNListLength = binary.BigEndian.Uint16(data[4:6]) sn.SNType = data[6] sn.SNNameLength = binary.BigEndian.Uint16(data[7:9]) + if int(9+sn.SNNameLength) > len(data) { + return ErrSliceBounds + } sn.SNName = string(data[9 : 9+sn.SNNameLength]) return nil } @@ -539,29 +545,37 @@ func (t *TLSMessage) printRecords() string { } func (t *TLSMessage) Parse(data []byte) error { + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) t.Records = make([]*Record, 0, 5) - if len(data) < headerSizeTLS { + if len(buf) < headerSizeTLS { return ErrTLSTooShort } - for len(data) > 0 { - ctype := data[0] + for i := 0; len(buf) > 0; i++ { + ctype := buf[0] ctdesc := ctdesc(ctype) if ctdesc == "Unknown" { + if i == 0 { + return fmt.Errorf("unknown content type") + } break } - if len(data) < 3 { + if len(buf) < 3 { break } - ver := binary.BigEndian.Uint16(data[1:3]) + ver := binary.BigEndian.Uint16(buf[1:3]) verdesc := verdesc(ver) if verdesc == "Unknown" { + if i == 0 { + return fmt.Errorf("unknown version") + } break } - if len(data) < headerSizeTLS { + if len(buf) < headerSizeTLS { break } - rlen := binary.BigEndian.Uint16(data[3:headerSizeTLS]) - rb := min(uint16(headerSizeTLS+rlen), uint16(len(data))) + rlen := binary.BigEndian.Uint16(buf[3:headerSizeTLS]) + rb := min(uint16(headerSizeTLS+rlen), uint16(len(buf))) if rb < headerSizeTLS { break } @@ -570,16 +584,17 @@ func (t *TLSMessage) Parse(data []byte) error { ContentTypeDesc: ctdesc, Version: &TLSVersion{Val: ver, Desc: verdesc}, Length: rlen, - Data: data[headerSizeTLS:rb], + Data: buf[headerSizeTLS:rb], } t.Records = append(t.Records, r) - data = data[rb:] + buf = buf[rb:] } - t.Data = data + t.Data = buf return nil } -func (t *TLSMessage) NextLayer() (layer string, payload []byte) { return } +func (t *TLSMessage) NextLayer() Layer { return nil } +func (t *TLSMessage) Name() LayerName { return LayerTLS } func ctdesc(ct uint8) string { // https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-5 diff --git a/layers/udp.go b/layers/udp.go index 96eff98..bca6ad5 100644 --- a/layers/udp.go +++ b/layers/udp.go @@ -19,15 +19,8 @@ type UDPSegment struct { Payload []byte } -func NewUDPSegment(srcPort, dstPort uint16, payload []byte, pseudo *[]byte) (*UDPSegment, error) { +func NewUDPSegment(srcPort, dstPort uint16, payload []byte) (*UDPSegment, error) { udp := &UDPSegment{SrcPort: srcPort, DstPort: dstPort, UDPLength: uint16(headerSizeUDP + len(payload)), Payload: payload} - if pseudo != nil { - var err error - udp.Checksum, err = CalculateUDPChecksum(append(*pseudo, udp.ToBytes()...)) - if err != nil { - return nil, err - } - } return udp, nil } @@ -71,11 +64,13 @@ func (u *UDPSegment) UnmarshalBinary(data []byte) error { if len(data) < headerSizeUDP { return fmt.Errorf("minimum header size for UDP is %d bytes, got %d bytes", headerSizeUDP, len(data)) } - u.SrcPort = binary.BigEndian.Uint16(data[0:2]) - u.DstPort = binary.BigEndian.Uint16(data[2:4]) - u.UDPLength = binary.BigEndian.Uint16(data[4:6]) - u.Checksum = binary.BigEndian.Uint16(data[6:headerSizeUDP]) - u.Payload = data[headerSizeUDP:] + buf := make([]byte, 0, len(data)) + buf = append(buf, data...) + u.SrcPort = binary.BigEndian.Uint16(buf[0:2]) + u.DstPort = binary.BigEndian.Uint16(buf[2:4]) + u.UDPLength = binary.BigEndian.Uint16(buf[4:6]) + u.Checksum = binary.BigEndian.Uint16(buf[6:headerSizeUDP]) + u.Payload = buf[headerSizeUDP:] return nil } @@ -84,8 +79,19 @@ func (u *UDPSegment) Parse(data []byte) error { return u.UnmarshalBinary(data) } -func (u *UDPSegment) NextLayer() (string, []byte) { - return nextAppLayer(u.SrcPort, u.DstPort), u.Payload +func (u *UDPSegment) NextLayer() Layer { + return ParseNextLayer(u.Payload, &u.SrcPort, &u.DstPort) +} + +func (u *UDPSegment) Name() LayerName { return LayerUDP } + +func (u *UDPSegment) SetChecksum(pseudo []byte) error { + checksum, err := CalculateUDPChecksum(append(pseudo, u.ToBytes()...)) + if err != nil { + return err + } + u.Checksum = checksum + return nil } func CalculateUDPChecksum(data []byte) (uint16, error) { diff --git a/mshark.go b/mshark.go index 67b9eee..d54879b 100644 --- a/mshark.go +++ b/mshark.go @@ -8,23 +8,21 @@ import ( "net" "os" "os/signal" + "strings" "time" - "github.com/mdlayher/packet" - "github.com/packetcap/go-pcap/filter" "github.com/shadowy-pycoder/mshark/layers" - "golang.org/x/net/bpf" + "github.com/shadowy-pycoder/mshark/network" ) -const unixEthPAll int = 0x03 - -var colorMap = map[int]string{ +var colorMap = map[int]string{ // TODO (shadowy-pycoder): add colors from shadowy-pycoder/colors 0: "\033[37m", 1: "\033[36m", 2: "\033[32m", 3: "\033[33m", 4: "\033[35m", } +var packetDelimeter = "\033[37m" + strings.Repeat("─", 66) + "\033[0m" var _ PacketWriter = &Writer{} @@ -80,23 +78,22 @@ func (mw *Writer) printPacket(layer layers.Layer, layerNum int) { // Timestamps are to be generated by the calling code. func (mw *Writer) WritePacket(timestamp time.Time, data []byte) error { mw.packets++ - fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05-0700")) - fmt.Fprintln(mw.w, "==================================================================") - next := layers.LayerMap["ETH"] + fmt.Fprintf(mw.w, "- Packet: %d Timestamp: %s\n", mw.packets, timestamp.Format("2006-01-02T15:04:05.000000-0700")) + fmt.Fprintln(mw.w, packetDelimeter) + next := layers.GetLayer(layers.LayerETH) + if next == nil { + return nil + } if err := next.Parse(data); err != nil { return err } var layerNum int mw.printPacket(next, layerNum) for { - name, data := next.NextLayer() - if name == "" || data == nil || len(data) == 0 { + next = next.NextLayer() + if next == nil { return nil } - next = layers.LayerMap[name] - if err := next.Parse(data); err != nil { - return err - } layerNum++ mw.printPacket(next, layerNum) } @@ -143,39 +140,15 @@ func (mw *Writer) WriteHeader(c *Config) error { // OpenLive opens a live capture based on the given configuration and writes // all captured packets to the given PacketWriters. func OpenLive(conf *Config, pw ...PacketWriter) error { - packetcfg := packet.Config{} - - // setting up filter - if conf.Expr != "" { - e := filter.NewExpression(conf.Expr) - f := e.Compile() - instructions, err := f.Compile() - if err != nil { - return fmt.Errorf("failed to compile filter into instructions: %v", err) - } - raw, err := bpf.Assemble(instructions) - if err != nil { - return fmt.Errorf("bpf assembly failed: %v", err) - } - packetcfg.Filter = raw - } - - // opening connection - c, err := packet.Listen(conf.Device, packet.Raw, unixEthPAll, &packetcfg) - if err != nil { - if errors.Is(err, os.ErrPermission) { - return fmt.Errorf("permission denied (try setting CAP_NET_RAW capability): %v", err) - } - return fmt.Errorf("failed to listen: %v", err) - } - + lc := &network.ListenConfig{Device: conf.Device, FilterExpr: conf.Expr} // setting promisc mode if conf.Device.Name != "any" { - if err := c.SetPromiscuous(conf.Promisc); err != nil { - return fmt.Errorf("unable to set promiscuous mode: %v", err) - } + lc.Promiscuous = &conf.Promisc + } + c, err := network.ListenPacket(lc) + if err != nil { + return err } - // timeout if conf.Timeout > 0 { if err := c.SetDeadline(time.Now().Add(conf.Timeout)); err != nil { diff --git a/network/network.go b/network/network.go index a9a1bf7..bdef9a4 100644 --- a/network/network.go +++ b/network/network.go @@ -3,6 +3,7 @@ package network import ( "bufio" + "errors" "fmt" "net" "net/netip" @@ -10,13 +11,69 @@ import ( "os/exec" "strings" "text/tabwriter" + + "github.com/mdlayher/packet" + "github.com/packetcap/go-pcap/filter" + "golang.org/x/net/bpf" ) +const ETH_P_ALL int = 0x03 + var ( BroadcastMAC = net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} LoopbackMAC = net.HardwareAddr{0x00, 0x00, 0x00, 0x00, 0x00, 0x00} ) +type ListenConfig struct { + Device *net.Interface // network interface which to bind to, if not specified default interface is used + Protocol int // network protocol, defaults to ETH_P_ALL + Promiscuous *bool // enable or disable promiscuous mode + FilterExpr string // packet filter expression like in tcpdump +} + +func ListenPacket(conf *ListenConfig) (*packet.Conn, error) { + packetcfg := packet.Config{} + // setting up filter + if conf.FilterExpr != "" { + e := filter.NewExpression(conf.FilterExpr) + f := e.Compile() + instructions, err := f.Compile() + if err != nil { + return nil, fmt.Errorf("failed to compile filter into instructions: %v", err) + } + raw, err := bpf.Assemble(instructions) + if err != nil { + return nil, fmt.Errorf("bpf assembly failed: %v", err) + } + packetcfg.Filter = raw + } + if conf.Device == nil { + var err error + conf.Device, err = GetDefaultInterface() + if err != nil { + return nil, err + } + } + if conf.Protocol == 0 { + conf.Protocol = ETH_P_ALL + } + // opening connection + c, err := packet.Listen(conf.Device, packet.Raw, conf.Protocol, &packetcfg) + if err != nil { + if errors.Is(err, os.ErrPermission) { + return nil, fmt.Errorf("permission denied (try setting CAP_NET_RAW capability): %v", err) + } + return nil, fmt.Errorf("failed to listen: %v", err) + } + // setting promisc mode + if conf.Promiscuous != nil { + if err := c.SetPromiscuous(*conf.Promiscuous); err != nil { + return nil, fmt.Errorf("unable to set promiscuous mode: %v", err) + } + } + return c, nil +} + // InterfaceByName returns the interface specified by name. func InterfaceByName(name string) (*net.Interface, error) { var ( @@ -42,7 +99,7 @@ func InterfaceByName(name string) (*net.Interface, error) { return in, nil } -func DisplayInterfaces() error { +func DisplayInterfaces(includeAny bool) error { w := new(tabwriter.Writer) w.Init(os.Stdout, 0, 0, 2, ' ', tabwriter.TabIndent) ifaces, err := net.Interfaces() @@ -50,7 +107,9 @@ func DisplayInterfaces() error { return fmt.Errorf("failed to get network interfaces: %v", err) } fmt.Fprintln(w, "Index\tName\tFlags") - fmt.Fprintln(w, "0\tany\tUP") + if includeAny { + fmt.Fprintln(w, "0\tany\tUP") + } for _, iface := range ifaces { fmt.Fprintf(w, "%d\t%s\t%s\n", iface.Index, iface.Name, strings.ToUpper(iface.Flags.String())) } @@ -80,6 +139,21 @@ func GetDefaultInterface() (*net.Interface, error) { return net.InterfaceByName(defaultInterface) } +func GetDefaultInterfaceFromRoute() (*net.Interface, error) { + cmd := exec.Command("sh", "-c", `ip -4 route get 8.8.8.8 | tr -d '\n'`) + routeRaw, err := cmd.Output() + if err != nil { + return nil, err + } + routeFields := strings.Fields(string(routeRaw)) + for i, f := range routeFields { + if f == "dev" && i+1 < len(routeFields) && routeFields[i+1] != "tun" { + return net.InterfaceByName(routeFields[i+1]) + } + } + return nil, fmt.Errorf("failed getting default interface from route") +} + func GetDefaultGatewayIPv4() (netip.Addr, error) { cmd := exec.Command("sh", "-c", `ip -4 route show 0.0.0.0/0 | awk '{print $3 " " $5}'`) ipdevRaw, err := cmd.Output() @@ -108,6 +182,22 @@ func GetDefaultGatewayIPv4() (netip.Addr, error) { return netip.Addr{}, fmt.Errorf("gateway IPv4 not found ") } +func GetDefaultGatewayIPv4FromRoute() (netip.Addr, error) { + cmd := exec.Command("sh", "-c", `ip -4 route get 8.8.8.8 | awk '{print $3}' | tr -d '\n'`) + ipstrRaw, err := cmd.Output() + if err != nil { + return netip.Addr{}, err + } + ip, err := netip.ParseAddr(string(ipstrRaw)) + if err != nil { + return netip.Addr{}, err + } + if !ip.IsValid() || !ip.Is4() { + return netip.Addr{}, fmt.Errorf("failed getting default gateway from route") + } + return ip, nil +} + func GetGatewayIPv4FromInterface(iface string) (netip.Addr, error) { cmd := exec.Command("sh", "-c", fmt.Sprintf("ip -4 route show dev %s", iface)) routes, err := cmd.Output() @@ -146,3 +236,48 @@ func GetIPv4PrefixFromInterface(iface *net.Interface) (netip.Prefix, error) { } return netip.Prefix{}, fmt.Errorf("no IPv4 prefix found") } + +func IsLocalAddress(addr string) bool { + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + ip := net.ParseIP(host) + if ip != nil { + return ip.IsLoopback() + } + host = strings.ToLower(host) + return strings.HasSuffix(host, ".local") || host == "localhost" +} + +// AddrEqual compares two address strings and returns true if they are equal. +// +// It treats loopback and unspecified IPs as equivalent. Returns false in case of inequality or error. +func AddrEqual(a, b string) bool { + if a == "" || b == "" { + return false + } + addr1, err := netip.ParseAddrPort(a) + if err != nil { + return false + } + addr2, err := netip.ParseAddrPort(b) + if err != nil { + return false + } + if addr1.Addr().IsLoopback() { + if addr1.Addr().Is4In6() || addr1.Addr().Is4() { + addr1 = netip.AddrPortFrom(netip.IPv4Unspecified(), addr1.Port()) + } else { + addr1 = netip.AddrPortFrom(netip.IPv6Unspecified(), addr1.Port()) + } + } + if addr2.Addr().IsLoopback() { + if addr2.Addr().Is4In6() || addr2.Addr().Is4() { + addr2 = netip.AddrPortFrom(netip.IPv4Unspecified(), addr2.Port()) + } else { + addr2 = netip.AddrPortFrom(netip.IPv6Unspecified(), addr2.Port()) + } + } + return addr1.Compare(addr2) == 0 +} diff --git a/version.go b/version.go index 1757501..3604c1f 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package mshark -const Version string = "mshark v0.0.9" +const Version string = "mshark v0.0.15"