Snippets

Sapu94 TSM Class Library

Created by Sapu Tradeskillmaster

File TSMClassLibrary.lua Added

  • Ignore whitespace
  • Hide word diff
+-- ------------------------------------------------------------------------------ --
+--                                TradeSkillMaster                                --
+--                          https://tradeskillmaster.com                          --
+--    All Rights Reserved* - Detailed license information included with addon.    --
+-- ------------------------------------------------------------------------------ --
+
+local private = { classInfo = {}, instInfo = {} }
+-- Set the keys as weak so that instances of classes can be GC'd (classes are never GC'd)
+setmetatable(private.instInfo, { __mode = "k" })
+local SPECIAL_PROPERTIES = {
+	__init = true,
+	__tostring = true,
+	__class = true,
+	__isa = true,
+	__super = true,
+	__name = true,
+}
+local DEFAULT_INST_FIELDS = {
+	__init = function(self)
+		-- do nothing
+	end,
+	__tostring = function(self)
+		return private.instInfo[self].str
+	end,
+}
+
+
+
+-- ============================================================================
+-- Global API Functions
+-- ============================================================================
+
+function TSMClassLibary_DefineClass(name, superclass, ...)
+	assert(type(name) == "string", "Invalid class name: "..tostring(name), 1)
+	local abstract = false
+	for i = 1, select('#', ...) do
+		local modifier = select(i, ...)
+		if modifier == "ABSTRACT" then
+			abstract = true
+		else
+			error("Invalid modifier: "..tostring(modifier))
+		end
+	end
+
+	local class = setmetatable({}, private.CLASS_MT)
+	private.classInfo[class] = {
+		name = name,
+		static = {},
+		superStatic = {},
+		superclass = superclass,
+		abstract = abstract,
+	}
+	while superclass do
+		for key, value in pairs(private.classInfo[superclass].static) do
+			if not private.classInfo[class].superStatic[key] then
+				private.classInfo[class].superStatic[key] = { class = superclass, value = value }
+			end
+		end
+		superclass = superclass.__super
+	end
+	return class
+end
+
+
+
+-- ============================================================================
+-- Instance Metatable
+-- ============================================================================
+
+private.INST_MT = {
+	__newindex = function(self, key, value)
+		assert(key ~= "__super" and key ~= "__isa" and key ~= "__class", "Can't set reserved key: "..tostring(key))
+		if private.classInfo[self.__class].static[key] ~= nil then
+			private.classInfo[self.__class].static[key] = value
+		elseif not private.instInfo[self].hasSuperclass then
+			-- we just set this directly on the instance table for better performance
+			rawset(self, key, value)
+		else
+			private.instInfo[self].fields[key] = value
+		end
+	end,
+	__index = function(self, key)
+		-- This method is super optimized since it's used for every class instance access, meaning function calls and
+		-- table lookup is kept to an absolute minimum, at the expense of readability and code reuse.
+		local instInfo = private.instInfo[self]
+
+		-- check if this key is an instance field first, since this is the most common case
+		local res = instInfo.fields[key]
+		if res ~= nil then
+			instInfo.currentClass = nil
+			return res
+		end
+
+		-- check if it's the special __super field
+		if key == "__super" then
+			-- The class of the current class method we are in, or nil if we're not in a class method.
+			local methodClass = instInfo.methodClass
+			-- We can only access the superclass within a class method and will use the class which defined that method
+			-- as the base class to jump to the superclass of, regardless of what class the instance actually is.
+			if not methodClass then
+				error("The superclass can only be referenced within a class method.")
+			end
+			instInfo.currentClass = private.classInfo[instInfo.currentClass or methodClass].superclass
+			if not instInfo.currentClass then
+				error("No super class found.")
+			end
+			return self
+		end
+
+		-- reset the current class since we're not continuing the __super chain
+		local class = instInfo.currentClass or instInfo.class
+		instInfo.currentClass = nil
+
+		-- check if this is a static key
+		local classInfo = private.classInfo[class]
+		res = classInfo.static[key]
+		if res ~= nil then
+			return res
+		end
+
+		-- check if it's a static field in the superclass
+		local superStaticRes = classInfo.superStatic[key]
+		if superStaticRes then
+			local superclass = superStaticRes.class
+			local superclassInfo = private.classInfo[superclass]
+			res = superStaticRes.value
+			return res
+		end
+
+		-- check if this field has a default value
+		res = DEFAULT_INST_FIELDS[key]
+		if res ~= nil then
+			return res
+		end
+
+		return nil
+	end,
+	__tostring = function(self)
+		return self:__tostring()
+	end,
+	__metatable = false,
+}
+
+
+
+-- ============================================================================
+-- Class Metatable
+-- ============================================================================
+
+private.CLASS_MT = {
+	__newindex = function(self, key, value)
+		assert(not private.classInfo[self].static[key], "Can't modify or override static members")
+		assert(key ~= "__super" and key ~= "__isa" and key ~= "__class", "Reserved word: "..key)
+		if type(value) == "function" then
+			-- We wrap class methods so that within them, the instance appears to be of the defining class
+			private.classInfo[self].static[key] = function(inst, ...)
+				local instInfo = private.instInfo[inst]
+				if not instInfo.isClassLookup[self] then
+					error(format("Attempt to call class method on non-object (%s)!", tostring(inst)))
+				end
+				if not instInfo.hasSuperclass then
+					-- don't need to worry about methodClass so just call the function directly
+					return value(inst, ...)
+				else
+					local prevMethodClass = instInfo.methodClass
+					instInfo.methodClass = self
+					return private.InstMethodReturnHelper(prevMethodClass, instInfo, value(inst, ...))
+				end
+			end
+		else
+			private.classInfo[self].static[key] = value
+		end
+	end,
+	__index = function(self, key)
+		-- check if it's the special __isa method which all classes implicitly have
+		if key == "__isa" then
+			return private.ClassIsA
+		elseif key == "__name" then
+			return private.classInfo[self].name
+		elseif key == "__super" then
+			return private.classInfo[self].superclass
+		end
+		error("Class type is write-only")
+	end,
+	__tostring = function(self)
+		return "class:"..private.classInfo[self].name
+	end,
+	__call = function(self, ...)
+		assert(not private.classInfo[self].abstract, "Attempting to instantiate an abstract class!")
+		-- Create a new instance of this class
+		local inst = {}
+		local instStr = strmatch(tostring(inst), "table:[^0-9a-fA-F]*([0-9a-fA-F]+)")
+		setmetatable(inst, private.INST_MT)
+		local hasSuperclass = private.classInfo[self].superclass and true or false
+		private.instInfo[inst] = {
+			class = self,
+			fields = {
+				__class = self,
+				__isa = private.InstIsA,
+			},
+			str = private.classInfo[self].name..":"..instStr,
+			isClassLookup = {},
+			hasSuperclass = hasSuperclass,
+		}
+		if not hasSuperclass then
+			-- set the static members directly on this object for better performance
+			for key, value in pairs(private.classInfo[self].static) do
+				if not SPECIAL_PROPERTIES[key] then
+					rawset(inst, key, value)
+				end
+			end
+		end
+		local c = self
+		while c do
+			private.instInfo[inst].isClassLookup[c] = true
+			c = private.classInfo[c].superclass
+		end
+		assert(select("#", inst:__init(...)) == 0, "__init must not return any values")
+		return inst
+	end,
+	__metatable = false,
+}
+
+
+
+-- ============================================================================
+-- Helper Functions
+-- ============================================================================
+
+function private.InstMethodReturnHelper(class, instInfo, ...)
+	-- reset methodClass now that the function returned
+	instInfo.methodClass = class
+	return ...
+end
+
+function private.InstIsA(inst, targetClass)
+	return private.instInfo[inst].isClassLookup[targetClass]
+end
+
+function private.ClassIsA(class, targetClass)
+	while class do
+		if class == targetClass then return true end
+		class = class.__super
+	end
+end

File TSMClassUnitTest.lua Added

  • Ignore whitespace
  • Hide word diff
+TestClass = {}
+function TestClass:TestBasic()
+	local Test = TSMClassLibary_DefineClass("Test")
+	function Test.__init(self)
+		self.initialized = true
+	end
+	function Test.GetMagicNumber(self)
+		return 0
+	end
+
+	local testInst = Test()
+	luaunit.assertTrue(testInst.initialized)
+	luaunit.assertEquals(testInst:GetMagicNumber(), 0)
+end
+
+function TestClass:TestSubClass()
+	local Test = TSMClassLibary_DefineClass("Test")
+	function Test.__init(self)
+		self.initialized = true
+		self.n = 2
+	end
+	function Test.GetMagicNumber(self)
+		return 0
+	end
+	function Test.Echo(self, ...)
+		return ...
+	end
+
+	local TestSub = TSMClassLibary_DefineClass("TestSub", Test)
+	function TestSub.__init(self)
+		self.__super:__init()
+		self.subInitialized = true
+	end
+	function TestSub.GetMagicNumber(self)
+		return self.__super:GetMagicNumber() + 1
+	end
+	function TestSub.GetText(self)
+		return "TEXT"
+	end
+
+	luaunit.assertTrue(Test:__isa(Test))
+	luaunit.assertTrue(TestSub:__isa(Test))
+	luaunit.assertTrue(TestSub:__isa(TestSub))
+
+	local testSubInst = TestSub()
+	luaunit.assertTrue(testSubInst:__isa(Test))
+	luaunit.assertTrue(testSubInst:__isa(TestSub))
+	luaunit.assertTrue(testSubInst.initialized)
+	luaunit.assertTrue(testSubInst.subInitialized)
+
+	luaunit.assertEquals(testSubInst.n, 2)
+	testSubInst.n = testSubInst.n + 1
+	luaunit.assertEquals(testSubInst.n, 3)
+
+	luaunit.assertEquals(testSubInst:GetMagicNumber(), 1)
+	luaunit.assertEquals(testSubInst:Echo(22), 22)
+	luaunit.assertEquals(testSubInst:GetText(), "TEXT")
+end
+
+function TestClass:TestVirtual()
+	local Test = TSMClassLibary_DefineClass("Test")
+	function Test.TestVirtual(self)
+		return 111
+	end
+	function Test.TestVirtualCaller(self)
+		return self:TestVirtual()
+	end
+	function Test.TestVirtualCaller2(self)
+		return self:TestVirtual2()
+	end
+
+	local TestSub = TSMClassLibary_DefineClass("TestSub", Test)
+	function TestSub.TestVirtual(self)
+		return 777
+	end
+	function TestSub.TestVirtual2(self)
+		return 333
+	end
+
+	local testSubInst = TestSub()
+	luaunit.assertEquals(testSubInst:TestVirtual(), 777)
+	luaunit.assertEquals(testSubInst:TestVirtualCaller(), 777)
+	luaunit.assertEquals(testSubInst:TestVirtual2(), 333)
+	luaunit.assertEquals(testSubInst:TestVirtualCaller2(), 333)
+end
HTTPS SSH

You can clone a snippet to your computer for local editing. Learn more.