;; Forth interpreter / machine / cegis.rkt
#lang racket

(require racket/system openssl/sha1 "programs.rkt" "stack.rkt" "state.rkt" "interpreter.rkt" "greensyn.rkt")

(provide optimize)
(provide estimate-time program-length perf-mode)
(provide z3 read-sexps)

;; When #t, the solver input files and the per-step dumps written under
;; debug/ are kept instead of being deleted. Cleared by perf-mode.
(define debug #t)
;; When #t, progress and result messages are printed. Cleared by perf-mode.
(define demo #t)
(define current-step 0) ; the number of the current cegis step
(define current-run 0)  ; the number of the current call to cegis.

;; Value handed to greensyn-reset; set-comm-length raises it to the largest
;; commstate sendp-*/recvp-* value seen so far.
(define comm-length 1)
;; Accumulated (input/output pair . commstate) entries used by cegis;
;; cleared by initialize.
(define all-pairs '())

;;; Prepare for a fresh optimization run: make sure the debug/ directory
;;; exists and reset the shared synthesis state.
(define (initialize)
  ;; make-directory* does nothing when the directory already exists, so
  ;; repeated runs (debug mode never removes debug/) no longer shell out
  ;; to a failing `mkdir`.
  (make-directory* "debug")
  (set! comm-length 1)
  (set! all-pairs '()))

;;; Remove the debug/ directory and everything inside it.
(define (finalize)
  ;; Done without the shell so it does not depend on `rm` being available;
  ;; a missing directory is simply ignored.
  (when (directory-exists? "debug")
    (delete-directory/files "debug")))

;;; Raise comm-length to the largest sendp-*/recvp-* field of `comm`, if
;;; any of them exceeds the current value. comm-length never shrinks here.
(define (set-comm-length comm)
  (for ([accessor (in-list (list commstate-sendp-u
                                 commstate-sendp-d
                                 commstate-sendp-l
                                 commstate-sendp-r
                                 commstate-recvp-u
                                 commstate-recvp-d
                                 commstate-recvp-l
                                 commstate-recvp-r))])
    (let ([len (accessor comm)])
      (when (> len comm-length)
        (set! comm-length len)))))

;;; Switch off both the debug dumps and the demo messages.
(define (perf-mode)
  (set!-values (debug demo) (values #f #f)))

;;; Builds a path under debug/ of the form
;;; debug/<name>-<prefix>-<run>-<step><suffix>, using the current run and
;;; step counters. The suffix (e.g. `.smt2') is appended verbatim, so it
;;; must include its own leading `.'.
(define (temp-file-name name prefix [suffix ""])
  (define stem
    (string-join (for/list ([part (list name prefix current-run current-step)])
                   (format "~a" part))
                 "-"))
  (format "debug/~a~a" stem suffix))

;;; Invoke `z3 <file>` through the shell and return everything it wrote to
;;; standard output as one string.
(define (z3 file)
  (define command (format "z3 ~a" file))
  (define captured (open-output-string))
  (parameterize ([current-output-port captured])
    (system command))
  (get-output-string captured))

;;; Three-way comparison of two strings: yields 'lt, 'eq or 'gt. When both
;;; strings parse as numbers they are compared numerically; otherwise they
;;; are compared as strings.
(define (compare x1 x2)
  ;; classify a and b using the given less/equal/greater predicates
  (define (order-by less? same? more? a b)
    (cond [(less? a b) 'lt]
          [(same? a b) 'eq]
          [(more? a b) 'gt]))
  (define n1 (string->number x1))
  (define n2 (string->number x2))
  (if (and n1 n2)
      (order-by < = > n1 n2)
      (order-by string<? string=? string>? x1 x2)))

;;; Ordering predicate for variable names. Names are split on `_' and the
;;; pieces are compared pairwise with `compare'; the first differing piece
;;; decides. If one name is a prefix of the other, the one with fewer
;;; pieces sorts first.
(define (var-name<? v1 v2)
  (define parts1 (regexp-split "_" v1))
  (define parts2 (regexp-split "_" v2))
  (let loop ([p1 parts1] [p2 parts2])
    (if (or (null? p1) (null? p2))
        ;; every shared piece was equal: the shorter name wins
        (< (length parts1) (length parts2))
        (let ([order (compare (car p1) (car p2))])
          (if (equal? order 'eq)
              (loop (cdr p1) (cdr p2))
              (equal? order 'lt))))))

;;; Turns a z3 model s-expression into a list of (name value) entries,
;;; sorted by variable name with var-name<?. The head of `model' is
;;; skipped; every remaining element is read as a (define-fun ...) form
;;; whose element 1 is the name and element 4 is the value.
(define (extract-model model)
  (define (entry-name entry) (format "~a" (car entry)))
  (define entries
    (for/list ([fun (in-list (cdr model))])
      (list (list-ref fun 1) (list-ref fun 4))))
  (sort entries var-name<? #:key entry-name))

;;; Given a model (list of (name value) entries), interprets the holes as
;;; instructions. Returns a pair (program-string . total-time) where the
;;; program string is the space-separated instructions of all `h_...'
;;; variables in model order, and total-time is the model's `total_time'.
;;; Raises an error when a required variable is missing from the model.
(define (model->program model)
  ;; value bound to `name' in the model, or a descriptive error
  (define (lookup name)
    (let ([entry (assoc name model)])
      (if entry
          (cadr entry)
          (error (format "~a not found in model!" name)))))
  ;; A regexp test instead of (substring ... 0 2), which raised on
  ;; variable names shorter than two characters.
  (define (is-hole var) (regexp-match? #rx"^h_" (format "~a" var)))
  (define (process-instr res)
    (if (= 0 (cadr res))
        ;; hole value 0 means a literal: its value is stored under the
        ;; matching hlit_... variable (first "h" replaced by "hlit")
        (format "~a" (lookup
                      (string->symbol
                       (regexp-replace "h" (format "~a" (car res)) "hlit"))))
        (format "~a" (vector-ref choice-id (cadr res)))))
  (define (time) (lookup 'total_time))
  (cons (string-join (map process-instr (filter (compose is-hole car) model)) " ") (time)))

;;; Rebuilds the input/output pair encoded in a model: a two-element list
;;; holding the progstate at step 0 and the progstate at step prog-length.
;;; Used to obtain a new pair after running the validator. The #:mem
;;; argument is accepted but not used here.
(define (model->pair model #:mem [mem 6] prog-length)
  (define (state-at step)
    ;; value of the model variable <reg>_<step>_v0
    (define (field reg)
      (let ([key (string->symbol (format "~a_~a_v0" reg step))])
        (cond [(assoc key model) => cadr]
              [else (error (format "~a not found in model!" key))])))
    (progstate (field 'a) (field 'b) 0 0 (field 'r) (field 's) (field 't)
               (stack (field 'sp) (field 'dst))
               (stack (field 'rp) (field 'rst))
               (field 'mem)))
  (list (state-at 0) (state-at prog-length)))

;;; Builds a commstate from the model: the send<i>_v0 and recv<i>_v0
;;; values followed by the sendp<i>_<prog-length>_v0 and
;;; recvp<i>_<prog-length>_v0 values, for i = 0..3.
(define (model->commstate model prog-length)
  ;; value of the model variable whose name is produced by the format call
  (define (value-of fmt . args)
    (cadr (assoc (string->symbol (apply format fmt args)) model)))
  (define ports '(0 1 2 3))
  (apply commstate
         (append
          (for/list ([i (in-list ports)]) (value-of "send~s_v0" i))
          (for/list ([i (in-list ports)]) (value-of "recv~s_v0" i))
          (for/list ([i (in-list ports)]) (value-of "sendp~s_~s_v0" i prog-length))
          (for/list ([i (in-list ports)]) (value-of "recvp~s_~s_v0" i prog-length)))))

;;; Splits the integer `input' into `size' consecutive 18-bit words,
;;; least-significant word first, and returns them as a vector.
(define (bytes->vector input size)
  (let loop ([remaining input] [words-left size] [acc '()])
    (if (= words-left 0)
        (list->vector (reverse acc))
        (loop (arithmetic-shift remaining -18)
              (sub1 words-left)
              (cons (bitwise-bit-field remaining 0 18) acc)))))

;;; Reads every s-expression available from `in', which may be a string or
;;; an input port, and returns them as a list in reading order.
(define (read-sexps in)
  (define port (if (string? in) (open-input-string in) in))
  (for/list ([datum (in-port read port)])
    datum))

;;; If the z3 output is sat, reads in the model (via extract-model). If it
;;; isn't -- the output contains `unsat', lacks `sat', or has no
;;; (model ...) form -- returns #f. `in' is a string or input port holding
;;; z3's output.
(define (read-model in)
  (define input (read-sexps in))
  ;; Only pairs can be (model ...) forms; checking pair? first keeps bare
  ;; answers such as `unknown' or `timeout' from raising in `car'.
  (define model
    (findf (lambda (x) (and (pair? x) (equal? (car x) 'model))) input))
  (and (not (member 'unsat input))
       (member 'sat input)
       model
       (extract-model model)))

;;; Runs the given F18A program from `start-state' (random by default) and
;;; returns ((input-state output-state) . commstate). The program is loaded
;;; at `memory-start'; the input state is captured after loading and the
;;; output state after execution. Also updates comm-length from the
;;; resulting commstate.
(define (random-pair program [memory-start 0]
                     #:start-state [start-state (random-state (expt 2 BIT))])
  (load-state! start-state)
  (load-program program memory-start)
  (define input-state (current-state))
  (reset-p! memory-start)
  (step-program!*)

  (define comm (current-commstate))
  (set-comm-length comm)

  (cons (list input-state (current-state)) comm))

;;; Registers one example with greensyn: the pair's input state, its
;;; output state and the associated commstate, followed by a commit.
(define (greensyn-add-pair pair comm)
  (let ([register-input  (lambda () (greensyn-input (car pair)))]
        [register-output (lambda () (greensyn-output (cadr pair)))])
    (register-input)
    (register-output)
    (greensyn-send-recv comm)
    (greensyn-commit)))

;;; Synthesis step: asks z3 for a program consistent with every entry of
;;; `previous-pairs' (each entry is (pair . commstate)). Raises an error if
;;; no pairs are given. Returns the result of model->program on the model
;;; z3 found, or #f when no model was read. In debug mode the raw z3
;;; output, the model, the first pair and the program are dumped under
;;; debug/; otherwise the generated .smt2 file is deleted.
(define (generate-candidate program previous-pairs name mem slots init repeat constraint time-limit num-bits inst-pool)
  (when demo (pretty-display "     + add pair"))
  (when (null? previous-pairs) (error "No input/output pairs given!"))
  (define smt-file (temp-file-name name "syn" ".smt2"))
  (greensyn-reset mem comm-length constraint #:num-bits num-bits #:inst-pool inst-pool)
  (for-each greensyn-add-pair (map car previous-pairs) (map cdr previous-pairs))

  (greensyn-check-sat #:file smt-file slots init repeat #:time-limit time-limit)

  (define solver-output (z3 smt-file))
  (define model (read-model solver-output))

  ;; write one debug artifact, replacing any previous file of that name
  (define (dump tag writer)
    (call-with-output-file (temp-file-name name tag) writer #:exists 'truncate))

  (cond
    [debug
     (dump "syn-model" (lambda (out) (display solver-output out)))
     (dump "syn-result"
           (lambda (out)
             (when model
               (for-each (lambda (p) (display p out) (newline out)) model))))
     (dump "pair" (lambda (out) (display (first previous-pairs) out)))
     (dump "program"
           (lambda (out)
             (when model (display (car (model->program model)) out))))]
    [else (delete-file smt-file)])

  (and model (model->program model)))

;;; Verification step: advances current-step, then asks z3 whether
;;; `candidate' differs from `spec'. Returns #f when no model is read
;;; (the candidate is accepted); otherwise returns the counter-example as
;;; (input/output-pair . commstate). In debug mode the model is dumped
;;; under debug/; otherwise the generated .smt2 file is deleted.
(define (validate spec candidate name mem prog-length constraint num-bits inst-pool)
  (set! current-step (add1 current-step))

  (define smt-file (temp-file-name name "verify" ".smt2"))
  (greensyn-reset mem comm-length constraint #:num-bits num-bits #:inst-pool inst-pool)
  (greensyn-spec spec)
  (greensyn-verify smt-file candidate)
  (define counter-model (read-model (z3 smt-file)))

  (if debug
      (call-with-output-file (temp-file-name name "verifier")
        (lambda (out)
          (when counter-model
            (for-each (lambda (binding) (display binding out) (newline out))
                      counter-model)))
        #:exists 'truncate)
      (delete-file smt-file))

  (and counter-model
       (cons (model->pair counter-model #:mem mem prog-length)
             (model->commstate counter-model prog-length))))

;;; Runs the whole CEGIS loop for `program': alternately generate a
;;; candidate from all-pairs and validate it, adding each counter-example
;;; to all-pairs. Returns the first candidate the validator accepts, as
;;; produced by generate-candidate, or #f when no candidate can be
;;; generated. If all-pairs is empty it is seeded with one pair obtained by
;;; running `program' from `start-state'.
(define (cegis program
               #:name [name "prog"]
               #:mem [mem 1]
               #:slots [slots 30]
               #:init [init 0]
               #:repeat [repeat 1]
               #:start [start 0]
               #:constraint [constraint constraint-all]
               #:time-limit [time-limit (estimate-time program)]
               #:num-bits [num-bits 18]
               #:inst-pool [inst-pool `no-fake]
               #:start-state [start-state (random-state (expt 2 BIT))]
               #:print-time [print-time #f])
  (define cegis-start (current-seconds))
  (reset! num-bits)
  (unless (nop-before-plus? program) (error "+ has to follow a nop unless it's the first instruction!"))
  (define program-for-ver (fix-@p program))
  (when demo
    (pretty-display
     (if (number? slots)
         (format ">> Synthesizing a program with <= ~a instructions, whose approx runtime < ~a ns." slots (* time-limit 0.5))
         (format ">> Synthesizing a program from ~e.\n   Approx runtime < ~a ns." (regexp-replace* #rx"\n" slots " ") (* time-limit 0.5)))))
  (set! current-run (add1 current-run))
  (set! current-step 0)

  ;; seed the example set on the first run
  (when (empty? all-pairs)
    (set! all-pairs (list (random-pair program start #:start-state start-state))))

  (define result
    (let refine ()
      (define candidate
        (generate-candidate program all-pairs name mem slots init repeat constraint time-limit num-bits inst-pool))
      (cond
        [(not candidate) #f]
        [(validate program-for-ver (car candidate) name mem (program-length program-for-ver) constraint num-bits inst-pool)
         => (lambda (counter-example)
              ;; the candidate is wrong on this example: remember it, retry
              (set! all-pairs (cons counter-example all-pairs))
              (refine))]
        [else
         (when demo
           (pretty-display (format "\tFound ~e.\n\tApprox runtime = ~e ns." (car candidate) (* (cdr candidate) 0.5))))
         candidate])))

  (when print-time
    (newline)
    (pretty-display (format "Time to synthesize: ~a seconds." (- (current-seconds) cegis-start))))
  result)

;;; Repeatedly runs cegis with a shrinking time limit: each time a
;;; candidate is found, the search is repeated with that candidate's time
;;; (its cdr) as the new limit. Returns the last candidate found, or
;;; `best-so-far' when cegis finds nothing.
(define (fastest-program program [best-so-far #f]
                         #:name       [name "prog"]
                         #:mem        [mem 1]
                         #:init       [init 0]
                         #:slots      [slots (program-length-abs program)]
                         #:repeat     [repeat 1]
                         #:start      [start mem]
                         #:constraint [constraint constraint-all]
                         #:time-limit [time-limit (add1 (estimate-time program))]
                         #:num-bits  [num-bits 18]
                         #:inst-pool [inst-pool `no-fake]
                         #:start-state [start-state (random-state (expt 2 BIT))])
  (define start-time (current-seconds))
  ;; The unused local binding of (fix-@p program) was dropped; cegis
  ;; performs the same call on the same program itself.
  (define candidate (cegis program #:name name
                           #:mem mem
                           #:slots slots #:init init #:repeat repeat
                           #:start start
                           #:constraint constraint #:time-limit time-limit
                           #:num-bits num-bits #:inst-pool inst-pool #:start-state start-state))
  (define result (if candidate
                     (fastest-program program candidate #:name name
                                      #:mem mem
                                      #:slots slots #:init init #:repeat repeat
                                      #:start start
                                      #:constraint constraint #:time-limit (cdr candidate)
                                      #:num-bits num-bits #:inst-pool inst-pool
                                      #:start-state start-state)
                     best-so-far))
  (when (and demo debug)
    (pretty-display (format "Time: ~a seconds." (- (current-seconds) start-time))))
  result)

;;; Binary search over the number of slots in [slot-min, slot-max]. Each
;;; probe runs cegis with the midpoint as #:slots. On success the search
;;; continues in the lower half with the candidate's time (its cdr) as the
;;; new time limit; on failure it continues in the upper half. Returns the
;;; last successful candidate, or `best-so-far' if none was found.
(define (binary-search slot-min slot-max init repeat program name mem start constraint time-limit num-bits inst-pool start-state [best-so-far #f])
  (if (> slot-min slot-max)
      best-so-far
      (let* ([slot-mid (quotient (+ slot-min slot-max) 2)]
             [candidate (cegis program
                               #:name name
                               #:mem mem #:slots slot-mid #:init init #:repeat repeat
                               #:start start
                               #:constraint constraint #:time-limit time-limit
                               #:num-bits num-bits #:inst-pool inst-pool
                               ;; forward the caller's start state; it was
                               ;; previously accepted but silently ignored
                               #:start-state start-state)])
        (if candidate
            (binary-search slot-min (sub1 slot-mid) init repeat program name mem start constraint (cdr candidate) num-bits inst-pool start-state candidate)
            (binary-search (add1 slot-mid) slot-max init repeat program name mem start constraint (if best-so-far (cdr best-so-far) time-limit) num-bits inst-pool start-state best-so-far)))))

;;; Prints a summary of the original program (when demo is on) and then
;;; binary-searches program lengths from 1 to `slots', returning whatever
;;; binary-search returns. `best-so-far' is accepted but not used.
(define (fastest-program3 program [best-so-far #f]
                          #:name       [name "prog"]
                          #:mem        [mem 1]
                          #:init       [init 0]
                          #:slots      [slots (program-length-abs program)]
                          #:repeat     [repeat 1]
                          #:start      [start mem]
                          #:constraint [constraint constraint-all]
                          #:time-limit [time-limit (estimate-time program)]
                          #:num-bits  [num-bits 18]
                          #:inst-pool [inst-pool `no-fake]
                          #:start-state [start-state (random-state (expt 2 BIT))])
  ;; format a message and print it on its own line
  (define (say fmt . args)
    (pretty-display (apply format fmt args)))

  (when demo
    (say "original program\t: ~e" program)
    (say "length\t\t\t: ~a" (program-length-abs program))
    (say "approx. runtime\t\t: ~a" (* time-limit 0.5)))

  ;; (pretty-display "PHASE 1: finding appropriate program length who runtime is less than the original.")

  ;; (pretty-display "PHASE 2: optimizing for runtime.")
  ;; (set! candidate 
  ;;   (if candidate
  ;; 	(fastest-program2 program candidate 
  ;; 			  #:name name #:mem mem
  ;; 			  #:slots (min (+ (program-length-abs (car candidate)) 2) slots)
  ;; 			  #:start start 
  ;; 			  #:constraint constraint #:time-limit (cdr candidate)
  ;; 			  #:num-bits num-bits #:inst-pool inst-pool)
  ;; 	#f))

  (binary-search 1 slots init repeat program name
                 mem start constraint
                 time-limit num-bits inst-pool start-state))

;; Optimize for the fastest running program. The runtime is estimated by summing runtime of 
;; all instructions in the given program without considering instruction fetching time.
;;
;; Output (on display):
;; The fastest F18A program that is equivalent to the given input program.
;; Programs that we can synthesize do not contain instructions that change 
;; the control flow of the program, which are ; ex jump call next if -if.
;; It also cannot synthesize !p instruction.
;;
;; Required arguments:
;; orig-program :: F18A program to be optimized. Literals have to be written in the @p form used by F18A,
;;                not arrayForth (e.g. @p @p @p @p 1 2 3 4). up down left right have to be written as
;;                UP DOWN LEFT RIGHT (in capital letters) and written as if they were literals
;;                (e.g. @p @p @p @p UP DOWN LEFT RIGHT). Multiport read and write are not supported.
;; 
;; Optional arguments:
;; name  :: description of the program.
;; mem   :: number of entries of memory. The more it is the longer the synthesizer takes. 
;;          Therefore, provide just enough for the program. Note that we only support storing data
;;          from memory 0th entry until mem-1'th entry and the program itself is stored starting at
;;          mem'th entry. 
;;          DEFAULT = 1
;; slots :: maximum length of the synthesized program.
;;          slots can be a string when the user wants to provide a sketch.
;;          For example, "_ . + _" means the synthesized program contains 4 instructions.
;;          The 1st and 4th instructions can be anything. The 2nd instruction is nop,
;;          and the 3rd instruction is plus.
;;          DEFAULT = original program's length
;; repeat :: When slots is a sketch in the form of a string, repeat can be used to indicate how many times
;;           the sketch is unrolled. For example, #:slots "dup _ _ ." #:repeat 3 means that
;;           the actual sketch is "dup _ _ . dup _ _ . dup _ _ ."
;;           DEFAULT = 1
;; init   :: Init is the additional header sketch that comes before slots.
;;           For example, #:init "over push - 2*" #:slots "dup _ _ ." #:repeat 3 means that
;;           the actual sketch is "over push - 2* dup _ _ . dup _ _ . dup _ _ ."
;;           DEFAULT = 0
;; start  :: The entry of the memory where the program is loaded to (starting from that entry).
;;           DEFAULT = mem
;; constraint :: The registers and/or stacks that contain the output you are looking for.
;;               For example, if you want to synthesize x y --> x+y, you might only care that
;;               register t (the top of the stack) is equal to x+y and not care whether other registers
;;               and stacks are changed or not. The synthesizer always constrains reads and writes to NSEW.
;;               Use "#:constraint constraint-all" to constrain everything
;;               (a b r s t data-stack return-stack memory).
;;               Use "#:constraint constraint-none" to constrain nothing except reads and writes
;;               to NSEW.
;;               Use "#:constraint (constraint <reg> ...)" to constrain <reg>.
;;               For example, to constrain a and t, use "#:constraint (constraint a t)"
;;               DEFAULT = constraint-all
;; time-limit :: The maximum runtime in ns of the synthesized program.
;;               DEFAULT = the original runtime
;; num-bits ::   number of bits of a word.
;;               DEFAULT = 18
;; inst-pool ::  Instructions available to compose the synthesized program. 
;;               #:inst-pool `no-fake = {@p @+ @b @ !+ !b ! +* 2* 2/ - + and or drop dup pop over a nop push b! a!}
;;               #:inst-pool `no-fake-no-p = `no-fake - {@p}
;;               #:inst-pool `no-mem = `no-fake - {@+ @b @ !+ !b !}
;;               #:inst-pool `no-mem = `no-mem - {@p}
;;               DEFAULT = `no-fake
;; bin-search :: When slots is a number, we perform a binary search on the length of the synthesized program.
;;               For example, if slots is 8, we will start searching for a program whose length is 4.
;;               If we find an equivalent program, we will search on length 2.
;;               If not, we will search on length 6. The process keeps going like a normal binary search.
;;               If bin-search is set to false, we will always search for a program whose length is equal
;;               to slots.
;;               DEFAULT = true

;; Entry point: preprocesses orig-program, resolves the slots and
;; time-limit defaults, runs the search and prints the outcome when demo
;; is on. Exits the process when mem is greater than 64. The debug/
;; directory is removed afterwards unless debug is on.
(define (optimize orig-program
                  #:name       [name "prog"]
                  #:mem        [mem 1]
                  #:init       [init 0]
                  #:slots      [raw-slots 0]
                  #:repeat     [repeat 1]
                  #:start      [start mem]
                  #:constraint [constraint constraint-all]
                  #:time-limit [raw-time-limit 0]
                  #:num-bits   [num-bits 18]
                  #:inst-pool  [inst-pool `no-fake]
                  #:bin-search [bin-search #t])
  ;; format a message and print it on its own line
  (define (say fmt . args)
    (pretty-display (apply format fmt args)))

  (when (> mem 64)
    (pretty-display "memory has to be less than 64!")
    (exit))
  (initialize)
  (set-udlr-from-constraints mem num-bits)

  (define program (preprocess orig-program))
  ;; slots: 0 means "as long as the original"; a string is a sketch
  (define slots
    (cond [(and (number? raw-slots) (= raw-slots 0)) (program-length-abs program)]
          [(string? raw-slots) (fix-@p (preprocess raw-slots))]
          [else raw-slots]))
  ;; time-limit: 0 means "no slower than the original"
  (define time-limit
    (if (= raw-time-limit 0)
        (estimate-time program)
        raw-time-limit))

  (define start-time (current-seconds))

  ;; binary search over lengths only applies to a numeric slot count
  (define search
    (if (and (number? slots) bin-search)
        fastest-program3
        fastest-program))
  (define result
    (search program
            #:name       name
            #:mem        mem
            #:init       init
            #:slots      slots
            #:repeat     repeat
            #:start      start
            #:constraint constraint
            #:time-limit time-limit
            #:num-bits   num-bits
            #:inst-pool  inst-pool
            #:start-state (random-state (expt 2 BIT))))

  (when demo
    (newline)
    (cond
      [result
       (say "output program\t\t: ~e" (postprocess (car result)))
       (say "length\t\t\t: ~a" (program-length-abs (car result)))
       (say "approx. runtime\t\t: ~a" (* (cdr result) 0.5))
       (newline)
       (pretty-display "Constants for neighbor ports:")
       (say "UP = ~a, DOWN = ~a, LEFT = ~a, RIGHT = ~a" UP DOWN LEFT RIGHT)
       (newline)]
      [else (say "No better implementation found.")])
    (say "Time to synthesize: ~a seconds." (- (current-seconds) start-time)))

  (unless debug
    (finalize)))