Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special case List<ClaimsIdentity> in SelectPrimaryIdentity #111799

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,31 @@ protected ClaimsPrincipal(SerializationInfo info, StreamingContext context)
{
ArgumentNullException.ThrowIfNull(identities);

foreach (ClaimsIdentity identity in identities)
// If the identities value is exactly a List<ClaimsIdentity>, special case it so that
// the enumerator allocation can be skipped. Doing this for List<ClaimsIdentity> is the 99%
// case because it is normally used on the _identities value, which is a List<ClaimsIdentity>.
if (identities.GetType() == typeof(List<ClaimsIdentity>))
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
{
if (identity != null)
List<ClaimsIdentity> identitiesList = (identities as List<ClaimsIdentity>)!;

for (int i = 0; i < identitiesList.Count; i++)
{
ClaimsIdentity identity = identitiesList[i];

if (identity != null)
{
return identity;
}
}
}
else
{
foreach (ClaimsIdentity identity in identities)
{
return identity;
if (identity != null)
{
return identity;
}
}
}

Expand Down
48 changes: 48 additions & 0 deletions src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -242,6 +243,53 @@ public void Current_FallsBackToThread_UnauthenticatedPrincipalPolicy()
}).Dispose();
}

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void PrimaryIdentitySelector_Default()
{
RemoteExecutor.Invoke(static () =>
{
ClaimsIdentity identity0 = null;
ClaimsIdentity identity1 = new([new Claim("type", "value")]);
ClaimsIdentity identity2 = new([new Claim("type", "value")]);
IEnumerable<ClaimsIdentity> identities = [identity0, identity1, identity2];
Func<IEnumerable<ClaimsIdentity>, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector;

Assert.Same(identity1, selector(identities));
Assert.Null(selector([]));
AssertExtensions.Throws<ArgumentNullException>("identities", () => selector(null));
}).Dispose();
}

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void PrimaryIdentitySelector_DefaultOnlySpecialCasesList()
{
RemoteExecutor.Invoke(static () =>
{
ClaimsIdentity identity0 = null;
ClaimsIdentity identity1 = new([new Claim("type", "value")]);
ClaimsIdentity identity2 = new([new Claim("type", "value")]);
ClaimsIdentityList identities = [identity0, identity1, identity2];
Func<ClaimsIdentityList, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector;

Assert.Same(identity1, selector(identities));
Assert.Equal(1, identities.GetEnumeratorCount);
Assert.Null(selector(new ClaimsIdentityList()));
}).Dispose();
}

private sealed class ClaimsIdentityList : List<ClaimsIdentity>, IEnumerable<ClaimsIdentity>
{
private readonly List<ClaimsIdentity> _claimsIdentities = [];

public int GetEnumeratorCount { get; private set; }

public new IEnumerator<ClaimsIdentity> GetEnumerator()
{
GetEnumeratorCount++;
return base.GetEnumerator();
}
}

private class NonClaimsPrincipal : IPrincipal
{
public IIdentity Identity { get; set; }
Expand Down
Loading