Commits

Olivier Grisel committed 1a55de4

temporary hack to work with wikiformat pictures

  • Participants
  • Parent commits f42c146
  • Branches commons-mediawiki-format

Comments (0)

Files changed (1)

File src/sgd/hashing/image.py

     for j, filename in enumerate(filenames):
         try:
             im = Image.open(os.path.join(image_folder, filename))
+            # crop to max size centered square
+            max = min(im.size)
+            center = (im.size[0] / 2, im.size[1] / 2)
+            im = im.crop(box = (center[0] - max / 2, center[1] - max / 2,
+                                center[0] + max / 2, center[1] + max / 2))
+
+
             # GIST is too slow to compute on big pictures
-            im.thumbnail((256, 256))
+            im.thumbnail((150, 150))
             features[1][j] = leargist.color_gist(im)
             noised_im = extractor.noisify(im)
             features[0][j] = leargist.color_gist(noised_im)
                                                         validation_size=5)
 
         autoencoders = make_stacked_autoencoders(
-            (960, 500, 128), 4, lambdas=[1e-3, 1e-3, 1e-3, 5e-4])
+            (960, 600, 300, 150, 64), 4, lambdas=[1e-3, 1e-3, 1e-3, 5e-4])
 #        autoencoders[0].modules[1].set_hyper_parameters(hp_1=0.6)
 #        autoencoders[0].modules[3].set_hyper_parameters(hp_1=0.6)
 #        autoencoders[1].modules[1].set_hyper_parameters(hp_1=1.2)
         complete_chunker = GistFeaturesChunker(gist_folder, validation_size=0)
         for i, aec in enumerate(autoencoders):
             logging.info("training autoencoder stage #%d", i)
-            aec.train_from_chunker(training_chunker, epochs=40,
-                                   validation_interval=5,
-                                   improvement_tol=1e-5)
+            #aec.train_from_chunker(training_chunker, epochs=40,
+            #                       validation_interval=5,
+            #                       improvement_tol=1e-5)
             enc = extract_encoder([aec])
             training_chunker = EncodingChunker(enc, training_chunker)
             complete_chunker = EncodingChunker(enc, complete_chunker)
 
+        cat_chunker = GistImageCaterogiesChunker(
+            gist_folder, validation_size=2, category_pattern=r"([-_\w]+)/.*\n")
+        cat_chunker_complete = GistImageCaterogiesChunker(
+            gist_folder, validation_size=0, category_pattern=r"([-_\w]+)/.*\n")
+        logging.info("categories: %r", cat_chunker_complete.get_class_names())
         logging.info("training using category information (supervised)")
-        cat_chunker = GistImageCaterogiesChunker(gist_folder, validation_size=2)
 
-        mlp_pretrained = extract_encoder(autoencoders[:-1], lambda_=1e-3)
-        mlp_out_dimensions = (
-            mlp_pretrained.label_dim, cat_chunker.category_size)
-        mlp_out = make_mlp(mlp_out_dimensions, loss='huber', lambda_=1e-3)
-        mlp = chain(mlp_pretrained, mlp_out)
-        mlp.train_from_chunker(cat_chunker, epochs=20,
+        #mlp_pretrained = extract_encoder(autoencoders[:-1], lambda_=1e-3)
+        #mlp_out_dimensions = (
+        #    mlp_pretrained.label_dim, cat_chunker.category_size)
+        #mlp_out = make_mlp(mlp_out_dimensions, loss='huber', lambda_=1e-3)
+        #mlp = chain(mlp_pretrained, mlp_out)
+        mlp = make_mlp((960, cat_chunker.category_size),
+                       loss='huber', lambda_=1e-4, output="linear")
+        mlp.train_from_chunker(cat_chunker, epochs=5,
                                validation_interval=10,
                                improvement_tol=1e-3)
-        cat_chunker_complete = GistImageCaterogiesChunker(gist_folder,
-                                                          validation_size=0)
         classifier = OneAgainstAllClassifier(
             mlp, cat_chunker_complete.get_class_names())
         classifier.log_report(cat_chunker_complete)