Snippets

Stefan Glienke Record comparer

Created by Stefan Glienke last modified
unit RecComparer;

interface

uses
  Generics.Defaults;

type
  // Pointer to a hand-built comparer instance; this is the type of the first
  // ("Self") argument the proxy methods in the vtable receive.
  PComparerData = ^TComparerData;
  // Hand-built interface instance. An interface reference pointing at a
  // TComparerData finds Vtable (the first field, a pointer to the slot array)
  // where an interface expects its method-table pointer, so the field order
  // below is significant and must not be changed.
  TComparerData = record // same layout as Generics.Defaults.TSimpleInstance
    Vtable: TArray<Pointer>;  // 5 slots: QueryInterface, _AddRef, _Release, Equals, GetHashCode
    RefCount: Integer;        // set to 0 by Create; the Nop _AddRef/_Release never touch it
    Size: Integer;            // SizeOf(T) as passed by the class constructor
    Default: IInterface;      // RTL default comparer whose Equals/GetHashCode slots are reused;
                              // presumably held so those copied slots stay valid
    Equals: Pointer;          // code address of T's "=" operator, when one was found
    GetHashCode: Pointer;     // code address of T.GetHashCode, when one was found
    constructor Create(typeSize: Integer; typeInfo: Pointer;
      const defaultComparer: IInterface;
      proxyEquals, proxyGetHashCode: Pointer);
  end;

  // Supplies an IEqualityComparer<T> for a record type T. It calls T's own
  // "=" operator and GetHashCode method when they exist and otherwise reuses
  // TEqualityComparer<T>.Default. The "record" constraint is commented out;
  // Default only checks it with an Assert.
  TRecordEqualityComparer<T{: record}> = class
  strict private class var
    // Shared comparer instance for this T, filled in by the class constructor.
    Instance: TComparerData;
    // Vtable proxy for IEqualityComparer<T>.Equals; forwards to inst.Equals.
    // "reintroduce" hides TObject.Equals.
    class function Equals(inst: PComparerData; const left, right: T): Boolean; reintroduce; static;
    // Vtable proxy for IEqualityComparer<T>.GetHashCode; forwards to inst.GetHashCode.
    // "reintroduce" hides TObject.GetHashCode.
    class function GetHashCode(inst: PComparerData; const value: T): Integer; reintroduce; static;
  public
    class constructor Create;
    // Returns the shared comparer instance as a non-reference-counted interface.
    class function Default: IEqualityComparer<T>; static;
  end;

implementation

uses
  Rtti,
  SysUtils,
  TypInfo;

// QueryInterface slot of the hand-built comparer vtable. No interface is ever
// handed out through it: it always answers E_NOINTERFACE. As the COM contract
// requires (and like TObject.GetInterface), the out pointer is cleared on
// failure so callers never see a stale value.
function NopQueryInterface(inst: Pointer; const IID: TGUID; out Obj): HResult; stdcall;
begin
  Pointer(Obj) := nil;
  Result := E_NOINTERFACE;
end;

// _AddRef slot of the hand-built comparer vtable: reference counting is
// disabled, so nothing is incremented and -1 is always reported.
function NopAddRef(inst: Pointer): Integer; stdcall;
begin
  Exit(-1);
end;

// _Release slot of the hand-built comparer vtable: reference counting is
// disabled, so nothing is freed and -1 is always reported.
function NopRelease(inst: Pointer): Integer; stdcall;
begin
  Exit(-1);
end;

// Builds a hand-rolled IEqualityComparer<T> instance laid out like
// Generics.Defaults.TSimpleInstance.
//
// typeSize         - stored in Size.
// typeInfo         - RTTI handle of T. When T is a record, its methods are
//                    scanned for a static "&op_Equality(const T; const T):
//                    Boolean" operator and an instance "GetHashCode: Integer".
// defaultComparer  - the default comparer for T. Its Equals/GetHashCode
//                    vtable slots are used unless a record method is found,
//                    and the reference is kept in Default.
// proxyEquals,
// proxyGetHashCode - entry points written into vtable slots 3/4 when the
//                    corresponding record method is found; they call through
//                    the Equals/GetHashCode code pointers stored here.
constructor TComparerData.Create(typeSize: Integer; typeInfo: Pointer;
  const defaultComparer: IInterface; proxyEquals, proxyGetHashCode: Pointer);
type
  PPVtable = ^PVtable;
  PVtable = ^TVtable;
  TVtable = array[0..4] of Pointer;
var
  ctx: TRttiContext;
  rttiType: TRttiType;
  method: TRttiMethod;
  returnType: Pointer;
  params: TArray<TRttiParameter>;
  isStatic: Boolean;
begin
  Default := defaultComparer;
  RefCount := 0;
  Size := typeSize;
  // Record constructors do not zero unmanaged fields; both pointers are tested
  // with Assigned in the loop below, so they must start out nil.
  Equals := nil;
  GetHashCode := nil;

  // Slots 0..2 are the IInterface methods (reference counting is a no-op);
  // slots 3..4 are IEqualityComparer<T>.Equals/GetHashCode, initially taken
  // from the default comparer's own vtable.
  SetLength(Vtable, 5);
  Vtable[0] := @NopQueryInterface;
  Vtable[1] := @NopAddRef;
  Vtable[2] := @NopRelease;
  Vtable[3] := PPVtable(defaultComparer)^^[3];
  Vtable[4] := PPVtable(defaultComparer)^^[4];

  // No RTTI (TypeInfo(T) may be nil) or not a record: keep the default slots.
  // Scanning a class type would, for example, pick up TObject.GetHashCode and
  // call it with the address of the reference instead of the object.
  rttiType := ctx.GetType(typeInfo);
  if (rttiType = nil) or not rttiType.IsRecord then
    Exit;

  // TODO: possibly use low level RTTI to be quicker and not have dependency on Rtti.pas
  for method in rttiType.GetMethods do
  begin
    // Both record methods found: nothing left to look for.
    if Assigned(GetHashCode) and Assigned(Equals) then
      Break;
    // Procedures can be neither candidate.
    if method.ReturnType = nil then
      Continue;
    returnType := method.ReturnType.Handle;
    isStatic := method.IsStatic;

    // class operator Equal(const left, right: T): Boolean
    if (returnType = System.TypeInfo(Boolean)) and isStatic
      and SameText(method.Name, '&op_Equality') then
    begin
      params := method.GetParameters;
      // BOTH operands must be const T so the operator matches the
      // "function(const left, right: T): Boolean" signature used to call it.
      if (Length(params) = 2)
        and (params[0].ParamType.Handle = typeInfo)
        and (pfConst in params[0].Flags)
        and (params[1].ParamType.Handle = typeInfo)
        and (pfConst in params[1].Flags) then
      begin
        Equals := method.CodeAddress;
        Vtable[3] := proxyEquals;
        Continue;
      end;
    end;

    // function GetHashCode: Integer (instance method without parameters)
    if (returnType = System.TypeInfo(Integer)) and not isStatic
      and SameText(method.Name, 'GetHashCode') then
    begin
      params := method.GetParameters;
      if Length(params) = 0 then
      begin
        GetHashCode := method.CodeAddress;
        Vtable[4] := proxyGetHashCode;
        Continue;
      end;
    end;
  end;
end;

{ TRecordEqualityComparer<T> }

// Builds the shared comparer for T once. It is seeded with the RTL default
// comparer (the fallback) and with this class's static proxy methods.
class constructor TRecordEqualityComparer<T>.Create;
var
  fallback: IEqualityComparer<T>;
begin
  fallback := TEqualityComparer<T>.Default;
  Instance := TComparerData.Create(SizeOf(T), TypeInfo(T), fallback,
    @Equals, @GetHashCode);
end;

// Returns the shared comparer for T. The interface pointer is aimed directly
// at the Instance record, whose first field holds the vtable; no AddRef is
// performed because Instance's _AddRef/_Release slots are no-ops.
class function TRecordEqualityComparer<T>.Default: IEqualityComparer<T>;
begin
  Assert(GetTypeKind(T) = tkRecord);

  // Release anything Result held, then store the raw pointer.
  Result := nil;
  PPointer(@Result)^ := @Instance;
end;

// Vtable entry for IEqualityComparer<T>.Equals, used when T declares an
// equality operator: forwards both operands to the operator whose code
// address is stored in inst.Equals.
class function TRecordEqualityComparer<T>.Equals(inst: PComparerData;
  const left, right: T): Boolean;
type
  TEqualityOperator = function(const left, right: T): Boolean;
var
  op: TEqualityOperator;
begin
  op := TEqualityOperator(inst^.Equals);
  Result := op(left, right);
end;

// Vtable entry for IEqualityComparer<T>.GetHashCode, used when T declares a
// GetHashCode method: invokes that instance method (stored in
// inst.GetHashCode) with the address of value as its Self.
class function TRecordEqualityComparer<T>.GetHashCode(
  inst: PComparerData; const value: T): Integer;
type
  TInstanceHashMethod = function(Self: Pointer): Integer;
var
  hashMethod: TInstanceHashMethod;
begin
  hashMethod := TInstanceHashMethod(inst^.GetHashCode);
  Result := hashMethod(@value);
end;

end.

Comments (0)