Snippets

Stefan Glienke: custom record equality comparer

Created by Stefan Glienke
program BetterRecordComparer;

{$APPTYPE CONSOLE}
{$O+,W-}

uses
  Generics.Defaults, Rtti, TypInfo, Hash;

{ Finds the code address of a record's "class operator Equal" overload
  (RTTI name '&op_Equality') declared as
    class operator Equal(const left, right: <record>): Boolean
  with the register calling convention.

  typeInfo: type information of the type to inspect; nil is tolerated.
  Returns: the operator's entry point, or nil when no overload with exactly
  that shape is visible through RTTI. }
function GetEqualsOperator(const typeInfo: PTypeInfo): Pointer;
const
  EqualsOperatorName = '&op_Equality';
var
  ctx: TRttiContext;
  rttiType: TRttiType;
  method: TRttiMethod;
  parameters: TArray<TRttiParameter>;
begin
  Result := nil;
  // GetType yields nil for a nil type info; bail out instead of dereferencing it.
  rttiType := ctx.GetType(typeInfo);
  if rttiType = nil then
    Exit;
  for method in rttiType.GetMethods(EqualsOperatorName) do
  begin
    if method.MethodKind <> mkOperatorOverload then
      Continue;
    if method.CallingConvention <> ccReg then
      Continue;
    // The operator is later invoked through TEqualsOperator, which returns
    // Boolean; an overload with any other result type must not be picked.
    if (method.ReturnType = nil) or (method.ReturnType.Handle <> System.TypeInfo(Boolean)) then
      Continue;
    parameters := method.GetParameters;
    if (Length(parameters) = 2)
      and (parameters[0].ParamType <> nil) and (parameters[1].ParamType <> nil)
      and (parameters[0].ParamType.Handle = typeInfo) and (parameters[1].ParamType.Handle = typeInfo)
      and (pfConst in parameters[0].Flags) and (pfConst in parameters[1].Flags) then
      Exit(method.CodeAddress);
  end;
end;

{ Finds the code address of a record's parameterless instance method
    function GetHashCode: Integer;
  declared with the register calling convention.

  typeInfo: type information of the type to inspect; nil is tolerated.
  Returns: the method's entry point, or nil when no matching method is
  visible through RTTI. }
function GetGetHashCode(const typeInfo: PTypeInfo): Pointer;
var
  ctx: TRttiContext;
  rttiType: TRttiType;
  method: TRttiMethod;
begin
  Result := nil;
  // GetType yields nil for a nil type info; bail out instead of dereferencing it.
  rttiType := ctx.GetType(typeInfo);
  if rttiType = nil then
    Exit;
  for method in rttiType.GetMethods('GetHashCode') do
  begin
    // mkFunction: an instance function, so ReturnType below is never nil.
    if method.MethodKind <> mkFunction then
      Continue;
    if method.CallingConvention <> ccReg then
      Continue;
    // System.TypeInfo is qualified because the parameter named typeInfo shadows it.
    if method.ReturnType.Handle <> System.TypeInfo(Integer) then
      Continue;
    if method.GetParameters = nil then
      Exit(method.CodeAddress);
  end;
end;

type
  // Demo record: it declares both an Equal operator and a parameterless
  // GetHashCode, the two members _LookupVtableInfo looks up through RTTI.
  TMyRec = record
    value: string;
    class operator Implicit(const value: string): TMyRec;

    class operator Equal(const left, right: TMyRec): Boolean;
    function GetHashCode: Integer;
  end;

  // Shape used to call a record's Equal operator through a raw code pointer.
  // The untyped const parameters forward each operand exactly as received.
  TEqualsOperator = function(const left, right): Boolean;
  // Shape used to call a record's GetHashCode method through a raw code
  // pointer; self is the address of the record instance (see GetHashCode_Method).
  TGetHashCode = function(self: Pointer): Integer;

  PComparerInstance = ^TComparerInstance;
  // Heap object handed out (via a hard cast) as an IEqualityComparer<T>.
  // Vtable must stay the first field: an interface reference points at this
  // record and the first pointer there is used as the method table.
  TComparerInstance = record
    Vtable: Pointer;             // method table passed to MakeInstance
    RefCount: Integer;           // maintained by MemAddref / MemRelease
    Size: Integer;               // SizeOf the compared record type
    Equals: TEqualsOperator;     // the record's Equal operator
    GetHashCode: TGetHashCode;   // the record's GetHashCode method
  end;

{ QueryInterface slot of the comparer's method table. The comparer exposes
  no other interfaces, so every request fails with E_NOINTERFACE.
  Per the COM QueryInterface contract, the out pointer is cleared on failure
  so callers never see a stale value. }
function NopQueryInterface(inst: Pointer; const IID: TGUID; out Obj): HResult; stdcall;
begin
  Pointer(Obj) := nil;
  Result := E_NOINTERFACE;
end;

{ _AddRef slot: atomically bumps the instance's reference count and
  returns the new count. }
function MemAddref(inst: PComparerInstance): Integer; stdcall;
var
  counter: PInteger;
begin
  counter := @inst^.RefCount;
  Result := AtomicIncrement(counter^);
end;

{ _Release slot: atomically drops the reference count, frees the heap
  block when the count reaches zero, and returns the new count. }
function MemRelease(inst: PComparerInstance): Integer; stdcall;
var
  remaining: Integer;
begin
  remaining := AtomicDecrement(inst^.RefCount);
  if remaining = 0 then
    FreeMem(inst);
  Result := remaining;
end;

// Equals slot of the comparer's method table: delegates to the record's
// Equal operator stored in the instance.
// The untyped const parameters forward whatever the caller supplied for each
// operand unchanged. NOTE(review): this presumably relies on the operator
// receiving its "const T" operands exactly as IEqualityComparer<T>.Equals
// does (in a register for small records, by address otherwise) - confirm
// before porting to another ABI.
function Equals_Method(inst: PComparerInstance; const left, right): Boolean;
begin
  Result := inst^.Equals(left, right);
end;

{ Returns True when a "const" record parameter of the given size is passed
  directly in a register (by value) rather than as a pointer to the value.
  Mirrors the size switch Generics.Defaults uses to choose its
  I1/I2/I4(/I8 on x64) comparers. Assumes Win32/Win64 - TODO confirm for
  other targets. }
function IsConstRecordPassedByValue(size: Integer): Boolean;
begin
  case size of
    1, 2, 4:
      Result := True;
{$IFDEF CPUX64}
    // 64-bit passes 8-byte const records in a register as well.
    8:
      Result := True;
{$ENDIF}
  else
    Result := False;
  end;
end;

{ GetHashCode slot of the comparer's method table.
  value holds what the caller passed for "const Value: T": either the record
  bits themselves (small records) or the record's address. The record's
  GetHashCode method always wants the record's address as self, so the
  by-value case hands over the address of the local copy instead. }
function GetHashCode_Method(inst: PComparerInstance; value: Pointer): Integer;
begin
  if IsConstRecordPassedByValue(inst^.Size) then
    Result := inst^.GetHashCode(@value)
  else
    Result := inst^.GetHashCode(value);
end;

const
  // Method table for TComparerInstance when used as IEqualityComparer<T>.
  // Slot order must match the interface: the three IInterface methods
  // (QueryInterface, _AddRef, _Release) followed by Equals and GetHashCode.
  EqualityComparer_Vtable_Method: array[0..4] of Pointer =
  (
    @NopQueryInterface,
    @MemAddref,
    @MemRelease,
    @Equals_Method,
    @GetHashCode_Method
  );

{ Allocates a TComparerInstance with GetMem and fills in every field.
  The reference count starts at zero; MemAddref / MemRelease maintain it,
  and MemRelease frees the block once it drops back to zero.
  Returns the raw instance pointer, ready to be hard-cast to an interface. }
function MakeInstance(vtable: Pointer; size: Integer;
  equals: TEqualsOperator; getHashCode: TGetHashCode): Pointer;
var
  instance: PComparerInstance;
begin
  GetMem(instance, SizeOf(TComparerInstance));
  instance^.Vtable := vtable;
  instance^.Size := size;
  instance^.RefCount := 0;
  instance^.GetHashCode := getHashCode;
  instance^.Equals := equals;
  Result := instance;
end;

{ Replacement for Generics.Defaults._LookupVtableInfo.
  For an equality comparer over a record type that declares both an Equal
  operator and a GetHashCode method (found through RTTI), it builds a comparer
  instance that calls those members. In every other case it defers to the
  RTL implementation.
  info may be nil for types without RTTI; that case goes straight to the
  RTL fallback, which handles it. }
function _LookupVtableInfo(intf: TDefaultGenericInterface; info: PTypeInfo; size: Integer): Pointer;
var
  equalsMethod, getHashCodeMethod: Pointer;
begin
  Result := nil;
  if (intf = giEqualityComparer) and (info <> nil) and (info.Kind = tkRecord) then
  begin
    equalsMethod := GetEqualsOperator(info);
    getHashCodeMethod := GetGetHashCode(info);
    // Both members are required; with only one of them the RTL comparer is used.
    if Assigned(equalsMethod) and Assigned(getHashCodeMethod) then
      Result := MakeInstance(@EqualityComparer_Vtable_Method, size, equalsMethod, getHashCodeMethod);
  end;
  if not Assigned(Result) then
    Result := Generics.Defaults._LookupVtableInfo(intf, info, size);
end;


{ TMyRec }

// Two TMyRec values are equal exactly when their value strings are equal.
class operator TMyRec.Equal(const left, right: TMyRec): Boolean;
begin
  if left.value <> right.value then
    Exit(False);
  Result := True;
end;

// Bob Jenkins hash (seed 0) over the raw UTF-16 bytes of value.
// PChar(value)^ addresses the first character and, unlike
// value[Low(string)], is also valid for an empty string: PChar('') points
// at a #0 terminator. The indexed form fails range checking under {$R+}.
// The hash is unchanged for non-empty strings.
function TMyRec.GetHashCode: Integer;
begin
  Result := THashBobJenkins.GetHashValue(PChar(value)^, Length(value) * SizeOf(Char), 0);
end;

// Lets a plain string be assigned where a TMyRec is expected.
class operator TMyRec.Implicit(const value: string): TMyRec;
var
  wrapped: TMyRec;
begin
  wrapped.value := value;
  Result := wrapped;
end;

var
  comparer: IEqualityComparer<TMyRec>;
  first, second: TMyRec;
begin
  // Two distinct records holding equal strings.
  first := 'a';
  second := 'a';

  // Obtain a comparer straight from the hook and compare its results with
  // the record's own GetHashCode.
  comparer := IEqualityComparer<TMyRec>(
    _LookupVtableInfo(giEqualityComparer, TypeInfo(TMyRec), SizeOf(TMyRec)));
  Writeln(comparer.Equals(first, second));
  Writeln(first.GetHashCode);
  Writeln(comparer.GetHashCode(first));
  Writeln(comparer.GetHashCode(second));
end.

Comments (0)

HTTPS SSH

You can clone a snippet to your computer for local editing. Learn more.