- 
                Notifications
    You must be signed in to change notification settings 
- Fork 5.2k
Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT #83945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| Tagging subscribers to this area: @JulieLeeMSFT, @jakobbotsch, @kunalspathak Issue DetailsUnroll  bool Test(ReadOnlySpan<byte> data)
{
    return "hello world!"u8.SequenceEqual(data);
}Codegen diff: https://www.diffchecker.com/0VOpmvMj/ LimitationsUnfortunately, it works only when a constant span (either RVA or e.g.  
 In theory, JIT is smart enough to perform things like: if (x == 42)
{
    Foo(x); // x will be replaced with 42
}via AssertProp, but in this case it's a bit more complicated than that. Perhaps, we can assist it with  MotivationMainly, these comparisons in TE. Benchmarks[Benchmark]
public int TE_Json()
{
    return GetRequestType("/json"u8);
}
[Benchmark]
public int TE_Plaintext()
{
    return GetRequestType("/plaintext"u8);
}
[MethodImpl(MethodImplOptions.NoInlining)]
private static int GetRequestType(ReadOnlySpan<byte> path)
{
    // Simulate TE scenario
    if (path.Length == 10 && Paths.Plaintext.SequenceEqual(path))
    {
        return 1;
    }
    else if (path.Length == 5 && Paths.Json.SequenceEqual(path))
    {
        return 2;
    }
    return 3;
}
static byte[] data1 = new byte[100];
static byte[] data2 = new byte[100];
[Benchmark]
public bool Equals_15()
{
    return data1.AsSpan(0, 15).SequenceEqual(data2.AsSpan(0, 15));
}
 
 | 
| Can be dasm for  G_M000_IG01:                ;; offset=0000H
G_M000_IG02:                ;; offset=0000H
       488B01               mov      rax, bword ptr [rcx]
       8B5108               mov      edx, dword ptr [rcx+08H]
       83FA0B               cmp      edx, 11
       7404                 je       SHORT G_M000_IG04
G_M000_IG03:                ;; offset=000BH
       33C0                 xor      eax, eax
       EB24                 jmp      SHORT G_M000_IG05
G_M000_IG04:                ;; offset=000FH
       48BA68656C6C6F20776F mov      rdx, 0x6F77206F6C6C6568
       483310               xor      rdx, qword ptr [rax]
       48B96C6F20776F726C64 mov      rcx, 0x646C726F77206F6C
       48334803             xor      rcx, qword ptr [rax+03H]
       480BD1               or       rdx, rcx
       0F94C0               sete     al
       0FB6C0               movzx    rax, al
G_M000_IG05:                ;; offset=0033H
       C3                   ret
; Total bytes of code 52? The assembly above is produced by this simple C# approach. Codeusing System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
ReadOnlySpan<byte> test = "hello world"u8;
Console.WriteLine(Test1(test));
#if !DEBUG
for (int i = 0; i < 100; ++i)
{
    if (i % 10 == 0) Thread.Sleep(100);
    _ = Test1(test);
}
#endif
static bool Test1(ReadOnlySpan<byte> data) => "hello world"u8.FastSequenceEqual(data);
internal static class MySpanExtensions
{
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    public static bool FastSequenceEqual(this ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
    {
        nuint len = (uint)left.Length;
        if ((uint)right.Length != len) return false;
        if (len >= sizeof(long) && len <= 2 * sizeof(long))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);
            long l0 = Unsafe.ReadUnaligned<long>(ref leftRef);
            long l1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref leftRef, len - sizeof(long)));
            long r0 = Unsafe.ReadUnaligned<long>(ref rightRef);
            long r1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref rightRef, len - sizeof(long)));
            long t0 = l0 ^ r0;
            long t1 = l1 ^ r1;
            long t = t0 | t1;
            return t == 0;
        }
        throw new NotSupportedException();
    }
}PS: the XOR-trick here is 👍🏻 | 
| // We're going to emit something like the following: | ||
| // | ||
| // bool result = ((*(int*)leftArg ^ *(int*)rightArg) | | ||
| // (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0; | ||
| // | ||
| // ^ in the given example we unroll for length=5 | ||
| // | ||
| // In IR: | ||
| // | ||
| // * EQ int | ||
| // +--* OR int | ||
| // | +--* XOR int | ||
| // | | +--* IND int | ||
| // | | | \--* LCL_VAR byref V1 | ||
| // | | \--* IND int | ||
| // | | \--* LCL_VAR byref V2 | ||
| // | \--* XOR int | ||
| // | +--* IND int | ||
| // | | \--* ADD byref | ||
| // | | +--* LCL_VAR byref V1 | ||
| // | | \--* CNS_INT int 1 | ||
| // | \--* IND int | ||
| // | \--* ADD byref | ||
| // | +--* LCL_VAR byref V2 | ||
| // | \--* CNS_INT int 1 | ||
| // \--* CNS_INT int 0 | ||
| // | ||
| GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def()); | ||
| GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def()); | ||
| GenTree* lXor = comp->gtNewOperNode(GT_XOR, TYP_INT, l1Indir, r1Indir); | ||
| GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth); | ||
| GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs); | ||
| GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs); | ||
| GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same | ||
| GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs); | ||
| GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs); | ||
| GenTree* rXor = comp->gtNewOperNode(GT_XOR, TYP_INT, l2Indir, r2Indir); | ||
| GenTree* resultOr = comp->gtNewOperNode(GT_OR, TYP_INT, lXor, rXor); | ||
| GenTree* zeroCns = comp->gtNewIconNode(0); | ||
| result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you sure this is better than the naive version for ARM64 with CCMPs? What is the ARM64 codegen diff if you create AND(EQ(IND, IND), EQ(IND, IND)) instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current codegen is (comparing 16 bytes):
F9400001          ldr     x1, [x0]
F9400043          ldr     x3, [x2]
CA030021          eor     x1, x1, x3
F9400000          ldr     x0, [x0]
F9400042          ldr     x2, [x2]
CA020000          eor     x0, x0, x2
AA000020          orr     x0, x1, x0
F100001F          cmp     x0, #0
9A9F17E0          cset    x0, eqcmp version presumably needs ifConversion path? Because here is what I see when I follow your suggestion:
F9400001          ldr     x1, [x0]
F9400043          ldr     x3, [x2]
EB03003F          cmp     x1, x3
9A9F17E1          cset    x1, eq
F9400000          ldr     x0, [x0]
F9400042          ldr     x2, [x2]
EB02001F          cmp     x0, x2
9A9F17E0          cset    x0, eq
EA00003F          tst     x1, x0
9A9F07E0          cset    x0, neso we need to either do this opt in codegen or earlier for that. For me arm64 codegen doesn't look too bad, it's still better than not unrolled.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this should not need if-conversion. Are you calling lowering on these new nodes? I would expect TryLowerAndOrToCCMP to kick in and the ARM64 "naive" IR to result in ldr, ldr, ldr, ldr, cmp, ccmp, cset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this should not need if-conversion. Are you calling lowering on these new nodes? I would expect
TryLowerAndOrToCCMPto kick in and the ARM64 "naive" IR to result in ldr, ldr, ldr, ldr, cmp, ccmp, cset.
still doesn't want to convert to CCMP, IsInvariantInRange check fails, presumably because of IND side effects. Still, I think the current version is better than non-unrolled
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should be able to insert it in the right order so that there is no interference, e.g. probably
t0 = IND
t1 = IND
t2 = IND
t3 = IND
t4 = EQ(t0, t1)
t5 = EQ(t2, t3)
t6 = AND(t4, t5)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although it's a bit odd there would be interference even with
t0 = IND
t1 = IND
t2 = EQ(t0, t1)
t3 = IND
t4 = IND
t5 = EQ(t3, t4)
t6 = AND(t2, t5)Probably something I should take a look at.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but just in case I pushed a change to move all IND nodes to the front
| I have a general question: why needs this to be done in JIT and not in managed code? Is it about throughput and / or IL-size? To implement this in pure C# something like  Example managed implementation[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool SequenceEqual(this ReadOnlySpan<byte> left, ReadOnlySpan<byte> right)
{
    nuint len = (uint)left.Length;
    if ((uint)right.Length != len) return false;
    if (/* missing piece */ RuntimeHelpers.IsKnownConstant(left))
    {
        if (len >= sizeof(int) && len <= 2 * sizeof(int))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);
            int l0 = Unsafe.ReadUnaligned<int>(ref leftRef);
            int r0 = Unsafe.ReadUnaligned<int>(ref rightRef);
            int l1 = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref leftRef, len - sizeof(int)));
            int r1 = Unsafe.ReadUnaligned<int>(ref Unsafe.Add(ref rightRef, len - sizeof(int)));
            int t0 = l0 ^ r0;
            int t1 = l1 ^ r1;
            int t = t0 | t1;
            return t == 0;
        }
        if (len >= sizeof(long) && len <= 2 * sizeof(long))
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);
            long l0 = Unsafe.ReadUnaligned<long>(ref leftRef);
            long r0 = Unsafe.ReadUnaligned<long>(ref rightRef);
            long l1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref leftRef, len - sizeof(long)));
            long r1 = Unsafe.ReadUnaligned<long>(ref Unsafe.Add(ref rightRef, len - sizeof(long)));
            long t0 = l0 ^ r0;
            long t1 = l1 ^ r1;
            long t = t0 | t1;
            return t == 0;
        }
        if (Vector128.IsHardwareAccelerated && len >= (uint)Vector128<byte>.Count && len <= 2 * (uint)Vector128<byte>.Count)
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);
            Vector128<byte> l0 = Vector128.LoadUnsafe(ref leftRef);
            Vector128<byte> r0 = Vector128.LoadUnsafe(ref rightRef);
            Vector128<byte> t0 = l0 ^ r0;
            Vector128<byte> l1 = Vector128.LoadUnsafe(ref leftRef, len - (uint)Vector128<byte>.Count);
            Vector128<byte> r1 = Vector128.LoadUnsafe(ref rightRef, len - (uint)Vector128<byte>.Count);
            Vector128<byte> t1 = l1 ^ r1;
            Vector128<byte> t = t0 | t1;
            return t == Vector128<byte>.Zero;
        }
        if (Vector256.IsHardwareAccelerated && len >= (uint)Vector256<byte>.Count && len <= 2 * (uint)Vector256<byte>.Count)
        {
            ref byte leftRef = ref MemoryMarshal.GetReference(left);
            ref byte rightRef = ref MemoryMarshal.GetReference(right);
            Vector256<byte> l0 = Vector256.LoadUnsafe(ref leftRef);
            Vector256<byte> r0 = Vector256.LoadUnsafe(ref rightRef);
            Vector256<byte> t0 = l0 ^ r0;
            Vector256<byte> l1 = Vector256.LoadUnsafe(ref leftRef, len - (uint)Vector256<byte>.Count);
            Vector256<byte> r1 = Vector256.LoadUnsafe(ref rightRef, len - (uint)Vector256<byte>.Count);
            Vector256<byte> t1 = l1 ^ r1;
            Vector256<byte> t = t0 | t1;
            return t == Vector256<byte>.Zero;
        }
    }
    // Current implementation of SequenceEqualCore w/o length-check (already done)
    return SequenceEqualCore(left, right);
}So it's more IL and more work for the JIT to do. Are these the reasons why it's done via  | 
Co-authored-by: Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
| @gfoidl there are two issues with the managed approach: 
 | 
| } | ||
| else if (strcmp(className, "SpanHelpers") == 0) | ||
| { | ||
| if (strcmp(methodName, "SequenceEqual") == 0) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know this is completely unrelated to the fix or changes being done, sorry... but has anyone ever tried reversing the methodName and className tests for a performance hack in the JIT itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@IDisposable lookupNamedIntrinsics never show up in our JIT traces so we don't bother. This code is only executed for methods with [Intrinsic] attribute so for 99% of methods it doesn't kick in.
We could use here a Trie/binary search if it was a real problem
| @EgorBo thanks for the info, I understand. | 
        
          
                src/coreclr/jit/lower.cpp
              
                Outdated
          
        
      | // Call LowerNode on these to create addressing modes if needed | ||
| LowerNode(l2Indir); | ||
| LowerNode(r2Indir); | ||
| LowerNode(lXor); | ||
| LowerNode(rXor); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like you could just make this function return the first new node you added, since the call was replaced anyway, and have "normal" lowering proceed from there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, done
| LIR::Use lArgUse; | ||
| LIR::Use rArgUse; | ||
| bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse); | ||
| bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse); | ||
| assert(lFoundUse && rFoundUse); | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a bit wasteful to go looking for the uses of this given that we know the arg they come from. E.g. you could do
| LIR::Use lArgUse; | |
| LIR::Use rArgUse; | |
| bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse); | |
| bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse); | |
| assert(lFoundUse && rFoundUse); | |
| CallArg* lArg = call->gtArgs.GetUserArgByIndex(0); | |
| GenTree*& lArgNode = lArg->GetLateNode() == nullptr ? lArg->EarlyNodeRef() : lArg->LateNodeRef(); | |
| ... | |
| LIR::Use lArgUse(BlockRange(), &lArgNode, call); | 
I don't have a super strong opinion on it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank, will check in a follow up once SPMI is collected - want to see if it's worth the effort to improve this expansion. jit-diff utils found around 30 methods only
Co-authored-by: Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
Unroll
SequenceEqualfor constant length[1..15](will add SIMD separately if this lands) for both x64 and arm64.Example (utf8 literal):
Codegen diff: https://www.diffchecker.com/E1laymuB/
Limitations
Unfortunately, it works only when a constant span (either RVA or e.g.
data.Slice(0, 10)) is on the left. It happens because we use left span's Length here:runtime/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs
Line 1437 in 922c75c
In theory, JIT is smart enough to perform things like:
via AssertProp, but in this case it's a bit more complicated than that. Perhaps, we can assist it with
IsKnowConstant. Or we can use RHS span's length instead if we think that a constant span is more likely to appear on the right side.Works for
StartsWith.Motivation
Mainly, these comparisons in TE.
Benchmarks
(the difference should be bigger when SIMD is enabled)