Snippets

Stefan Glienke Custom record equalitycomparer

Created by Stefan Glienke

File snippet.txt Added

  • Ignore whitespace
  • Hide word diff
+program BetterRecordComparer;
+
+{$APPTYPE CONSOLE}
+{$O+,W-}
+
+uses
+  Generics.Defaults, Rtti, TypInfo, Hash;
+
+function GetEqualsOperator(const typeInfo: PTypeInfo): Pointer;
+const
+  EqualsOperatorName = '&op_Equality';
+var
+  ctx: TRttiContext;
+  method: TRttiMethod;
+  parameters: TArray<TRttiParameter>;
+begin
+  for method in ctx.GetType(typeInfo).GetMethods(EqualsOperatorName) do
+  begin
+    if method.MethodKind <> mkOperatorOverload then
+      Continue;
+    if method.CallingConvention <> ccReg then
+      Continue;
+    parameters := method.GetParameters;
+    if (Length(parameters) = 2)
+      and (parameters[0].ParamType.Handle = typeInfo) and (parameters[1].ParamType.Handle = typeInfo)
+      and (pfConst in parameters[0].Flags) and (pfConst in parameters[1].Flags) then
+     Exit(method.CodeAddress);
+  end;
+  Result := nil;
+end;
+
+function GetGetHashCode(const typeInfo: PTypeInfo): Pointer;
+var
+  ctx: TRttiContext;
+  method: TRttiMethod;
+begin
+  for method in ctx.GetType(typeInfo).GetMethods('GetHashCode') do
+  begin
+    if method.MethodKind <> mkFunction then
+      Continue;
+    if method.CallingConvention <> ccReg then
+      Continue;
+    if method.ReturnType.Handle <> System.TypeInfo(Integer) then
+      Continue;
+    if method.GetParameters = nil then
+      Exit(method.CodeAddress);
+  end;
+  Result := nil;
+end;
+
+type
+  TMyRec = record
+    value: string;
+    class operator Implicit(const value: string): TMyRec;
+
+    class operator Equal(const left, right: TMyRec): Boolean;
+    function GetHashCode: Integer;
+  end;
+
+  TEqualsOperator = function(const left, right): Boolean;
+  TGetHashCode = function(self: Pointer): Integer;
+
+  PComparerInstance = ^TComparerInstance;
+  TComparerInstance = record
+    Vtable: Pointer;
+    RefCount: Integer;
+    Size: Integer;
+    Equals: TEqualsOperator;
+    GetHashCode: TGetHashCode;
+  end;
+
+function NopQueryInterface(inst: Pointer; const IID: TGUID; out Obj): HResult; stdcall;
+begin
+  Result := E_NOINTERFACE;
+end;
+
+function MemAddref(inst: PComparerInstance): Integer; stdcall;
+begin
+  Result := AtomicIncrement(inst^.RefCount);
+end;
+
+function MemRelease(inst: PComparerInstance): Integer; stdcall;
+begin
+  Result := AtomicDecrement(inst^.RefCount);
+  if Result = 0 then
+    FreeMem(inst);
+end;
+
+function Equals_Method(inst: PComparerInstance; const left, right): Boolean;
+begin
+  Result := inst^.Equals(left, right);
+end;
+
+function GetHashCode_Method(inst: PComparerInstance; value: Pointer): Integer;
+begin
+  if inst.size <= 4 then // check for 64bit
+    Result := inst^.GetHashCode(@value)
+  else
+    Result := inst^.GetHashCode(value);
+end;
+
+const
+  EqualityComparer_Vtable_Method: array[0..4] of Pointer =
+  (
+    @NopQueryInterface,
+    @MemAddref,
+    @MemRelease,
+    @Equals_Method,
+    @GetHashCode_Method
+  );
+
+function MakeInstance(vtable: Pointer; size: Integer;
+  equals: TEqualsOperator; getHashCode: TGetHashCode): Pointer;
+var
+  inst: PComparerInstance;
+begin
+  GetMem(inst, SizeOf(inst^));
+  inst^.Vtable := vtable;
+  inst^.RefCount := 0;
+  inst^.Size := size;
+  inst^.Equals := equals;
+  inst^.GetHashCode := getHashCode;
+  Result := inst;
+end;
+
+function _LookupVtableInfo(intf: TDefaultGenericInterface; info: PTypeInfo; size: Integer): Pointer;
+var
+  equalsMethod, getHashCodeMethod: Pointer;
+begin
+  Result := nil;
+  if (intf = giEqualityComparer) and (info.Kind = tkRecord) then
+  begin
+    equalsMethod := GetEqualsOperator(info);
+    getHashCodeMethod := GetGetHashCode(info);
+    if Assigned(equalsMethod) and Assigned(getHashCodeMethod) then
+      Result := MakeInstance(@EqualityComparer_Vtable_Method, size, equalsMethod, getHashCodeMethod);
+  end;
+  if not Assigned(Result) then
+    Result := Generics.Defaults._LookupVtableInfo(intf, info, size);
+end;
+
+
+{ TMyRec }
+
+class operator TMyRec.Equal(const left, right: TMyRec): Boolean;
+begin
+  Result := left.value = right.value;
+end;
+
+function TMyRec.GetHashCode: Integer;
+begin
+  Result := THashBobJenkins.GetHashValue(Value[Low(string)], Length(value) * SizeOf(Char), 0);
+end;
+
+class operator TMyRec.Implicit(const value: string): TMyRec;
+begin
+  Result.value := value;
+end;
+
+var
+  c: IEqualityComparer<TMyRec>;
+  r1, r2: TMyRec;
+begin
+  r1 := 'a';
+  r2 := 'a';
+
+  c := IEqualityComparer<TMyRec>(_LookupVtableInfo(giEqualityComparer, TypeInfo(TMyRec), SizeOf(TMyRec)));
+  Writeln(c.Equals(r1, r2));
+  Writeln(r1.GetHashCode);
+  Writeln(c.GetHashCode(r1));
+  Writeln(c.GetHashCode(r2));
+end.
HTTPS SSH

You can clone a snippet to your computer for local editing. Learn more.