Commits

Jason R. Coombs  committed d698058

Added helper methods to retrieve parts from a package by content type and by class

  • Participants
  • Parent commits 4bd8d93

Comments (0)

Files changed (2)

File openpack/basepack.py

 		if ct is None:
 			log.warning('no content type found for part %(name)s' % vars())
 			return
-		part = Part(self, name, ct, data=data)
+		part = Part(self, name, data=data)
 		self[name] = part
 
 	@handle('http://schemas.openxmlformats.org/package/2006/relationships/metadata/core-properties')
 	def __repr__(self):
 		return "Package-%s" % id(self)
 
+	def get_parts_by_class(self, cls):
+		"""
+		Return all parts of this package that are instances of cls
+		(where cls is passed directly to isinstance, so can be a class
+		or sequence of classes).
+		"""
+		return (part for part in self.parts.values() if isinstance(part, cls))
+
+	def get_parts_by_content_type(self, content_type):
+		# first find any parts who's registered type matches or who's
+		#  content_type attribute matches
+		return (
+			part
+			for part in self.parts.values()
+			if self.content_types.find_for(part.name) == content_type
+			or part.content_type == content_type
+			)
+
 class Part(Relational):
 	"""Parts are the building blocks of OOXML files.
 
 	content_type = None
 	rel_type = None
 
-	def __init__(self, package, name, growth_hint=None, data=None):
+	def __init__(self, package, name, content_type=None, rel_type=None, growth_hint=None, data=None):
 		self.name = name
 		self.package = package
+		if content_type is not None:
+			self.content_type = content_type
+		if rel_type is not None:
+			self.rel_type = rel_type
 		self.growth_hint = growth_hint
 		if not isinstance(self, Relationships):
 			self.relationships = Relationships(self.package, self)

File test/test_basepack.py

 		p = SamplePart(self.pack, '/pmx/samp.vpart')
 		self.pack.add(p, override=False)
 		ct = self.pack.content_types.find_for('/pmx/samp.vpart')
+		assert isinstance(ct, ContentType.Default)
 		assert ct is not None
+		assert ct.key == 'vpart'
+		assert ct.name == p.content_type
 		
 	def test_add_no_override(self):
 		self.test_create()
-		p = SamplePart(self.pack, '/pmx/samp.main')
-		p.content_type = "app/pmxmain+xml"
+		p = SamplePart(self.pack, '/pmx/samp.main', content_type='app/pmxmain+xml')
 		self.pack.add(p)
 		ct = self.pack.content_types.find_for('/pmx/samp.main') 
+		assert isinstance(ct, ContentType.Override)
 		assert ct is not None
-		assert ct.name == 'app/pmxmain+xml'
+		assert ct.key == p.name
+		assert ct.name == p.content_type
+
+	def test_get_parts_by_content_type(self):
+		pack = Package()
+		part = SamplePart(pack, '/pmx/samp.main')
+		pack.add(part)
+		parts = pack.get_parts_by_content_type(part.content_type)
+		assert parts.next() is part
+		py.test.raises(StopIteration, parts.next)
+		ct = pack.content_types.find_for(part.name)
+		parts = pack.get_parts_by_content_type(ct)
+		assert parts.next() is part
+		py.test.raises(StopIteration, parts.next)
+
+	def test_get_parts_by_class(self):
+		pack = Package()
+		part = SamplePart(pack, '/pmx/samp.main')
+		pack.add(part)
+		parts = pack.get_parts_by_class(SamplePart)
+		assert parts.next() is part
+		py.test.raises(StopIteration, parts.next)
 
 class TestContentTypes:
 	def test_no_duplicates_in_output(self):