Source

stm-benchmarks / raytrace / raytrace.py

# From http://www.reddit.com/r/tinycode/comments/169ri9/ray_tracer_in_140_sloc_of_python_with_picture/
# Date: 14.03.2013

from math import sqrt, pow, pi
from common.abstract_threading import atomic, Future

AMBIENT = 0.1
GAMMA_CORRECTION = 1/2.2


class Vector:
    def __init__(self,x,y,z):
        self.x = x
        self.y = y
        self.z = z

    def dot(self, b):
        return self.x*b.x + self.y*b.y + self.z*b.z

    def cross(self, b):
        return (self.y*b.z-self.z*b.y, self.z*b.x-self.x*b.z, self.x*b.y-self.y*b.x)

    def magnitude(self):
        return sqrt(self.x**2+self.y**2+self.z**2)

    def normal(self):
        mag = self.magnitude()
        return Vector(self.x/mag,self.y/mag,self.z/mag)

    def __add__(self, b):
        return Vector(self.x + b.x, self.y+b.y, self.z+b.z)

    def __sub__(self, b):
        return Vector(self.x-b.x, self.y-b.y, self.z-b.z)

    def __mul__(self, b):
        assert type(b) == float or type(b) == int
        return Vector(self.x*b, self.y*b, self.z*b)


class Sphere:
    def __init__(self, center, radius, color):
        self.c = center
        self.r = radius
        self.col = color

    def intersection(self, l):
        q = l.d.dot(l.o - self.c)**2 - (l.o - self.c).dot(l.o - self.c) + self.r**2
        if q < 0:
            return Intersection( Vector(0,0,0), -1, Vector(0,0,0), self)
        else:
            d = -l.d.dot(l.o - self.c)
            d1 = d - sqrt(q)
            d2 = d + sqrt(q)
            if 0 < d1 and ( d1 < d2 or d2 < 0):
                return Intersection(l.o+l.d*d1, d1, self.normal(l.o+l.d*d1), self)
            elif 0 < d2 and ( d2 < d1 or d1 < 0):
                return Intersection(l.o+l.d*d2, d2, self.normal(l.o+l.d*d2), self)
            else:
                return Intersection( Vector(0,0,0), -1, Vector(0,0,0), self)

    def normal(self, b):
        return (b - self.c).normal()


class Plane:
    def __init__(self, point, normal, color):
        self.n = normal
        self.p = point
        self.col = color

    def intersection(self, l):
        d = l.d.dot(self.n)
        if d == 0:
            return Intersection( Vector(0,0,0), -1, Vector(0,0,0), self)
        else:
            d = (self.p - l.o).dot(self.n) / d
            return Intersection(l.o+l.d*d, d, self.n, self)


class Ray:
    def __init__(self, origin, direction):
        self.o = origin
        self.d = direction


class Intersection:
    def __init__(self, point, distance, normal, obj):
        self.p = point
        self.d = distance
        self.n = normal
        self.obj = obj


def testRay(ray, objects, ignore=None):
    intersect = Intersection( Vector(0,0,0), -1, Vector(0,0,0), None)

    for obj in objects:
        if obj is not ignore:
            currentIntersect = obj.intersection(ray)
            if currentIntersect.d > 0 and intersect.d < 0:
                intersect = currentIntersect
            elif 0 < currentIntersect.d < intersect.d:
                intersect = currentIntersect
    return intersect


def trace(ray, objects, light, maxRecur):
    if maxRecur < 0:
        return (0,0,0)
    intersect = testRay(ray, objects)
    if intersect.d == -1:
        col = Vector(AMBIENT,AMBIENT,AMBIENT)
    elif intersect.n.dot(light - intersect.p) < 0:
        col = intersect.obj.col * AMBIENT
    else:
        lightRay = Ray(intersect.p, (light-intersect.p).normal())
        if testRay(lightRay, objects, intersect.obj).d == -1:
            lightIntensity = 1000.0/(4*pi*(light-intersect.p).magnitude()**2)
            col = intersect.obj.col * max(intersect.n.normal().dot((light - intersect.p).normal()*lightIntensity), AMBIENT)
        else:
            col = intersect.obj.col * AMBIENT
    return col


def gammaCorrection(color,factor):
    return (int(pow(color.x/255.0,factor)*255),
            int(pow(color.y/255.0,factor)*255),
            int(pow(color.z/255.0,factor)*255))


def task(x, h, cameraPos, objs, lightSource):
    with atomic:
        for y in range(h):
            ray = Ray(cameraPos,
                      (Vector(x/50.0-5,y/50.0-5,0)-cameraPos).normal())
            trace(ray, objs, lightSource, 10)

futures = []
def future_dispatcher(ths, *args):
    if len(futures) >= ths:
        for f in futures:
            f()
        del futures[:]

    futures.append(Future(task, *args))




def run(ths=4, w=300, h=300):
    ths = int(ths)
    w = int(w)
    h = int(h)

    objs = []
    objs.append(Sphere( Vector(-2,0,-10), 2, Vector(0,255,0)))
    objs.append(Sphere( Vector(2,0,-10), 3.5, Vector(255,0,0)))
    objs.append(Sphere( Vector(0,-4,-10), 3, Vector(0,0,255)))
    objs.append(Plane( Vector(0,0,-12), Vector(0,0,1), Vector(255,255,255)))
    lightSource = Vector(0,10,0)

    cameraPos = Vector(0,0,20)
    
    for x in range(w):
        print x
        future_dispatcher(ths, x, h, cameraPos, objs, lightSource)

    for f in futures:
        f()
    del futures[:]



if __name__ == '__main__':
    run()