View Javadoc
1   package net.secodo.jcircuitbreaker.breaker.impl;
2   
import net.secodo.jcircuitbreaker.breaker.CircuitBreaker;
import net.secodo.jcircuitbreaker.breaker.execution.ExecutedTask;
import net.secodo.jcircuitbreaker.breakhandler.BreakHandler;
import net.secodo.jcircuitbreaker.breakstrategy.BreakStrategy;
import net.secodo.jcircuitbreaker.exception.TaskExecutionException;
import net.secodo.jcircuitbreaker.task.Task;
import org.junit.Test;

import java.lang.reflect.Field;
import java.util.Map;
import java.util.Objects;
import java.util.Queue;
import java.util.Stack;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
21  
22  
23  @SuppressWarnings("unchecked")
24  public class DefaultCircuitBreakerConcurrencyTest {
25    @Test(timeout = 20000L)
26    public void shouldCorrectlyHandleNumberOfTasksInProgress() throws Exception {
27      // given
28      final DefaultCircuitBreaker circuitBreaker = new DefaultCircuitBreaker();
29  
30      final Field tasksInProgressField = AbstractCircuitBreaker.class.getDeclaredField("tasksInProgress");
31      tasksInProgressField.setAccessible(true);
32  
33      final Map<String, ExecutedTask> tasksInProgress = (Map<String, ExecutedTask>) tasksInProgressField.get(
34        circuitBreaker);
35  
36      final BreakStrategy breakStrategy = mock(BreakStrategy.class);
37      final BreakHandler<Long> breakHandler = mock(BreakHandler.class);
38  
39  
40      //    final SomeTestClassWithLongRunningMethod someObject = new SomeTestClassWithLongRunningMethod();
41      //    Callable<Long> methodCall = () -> someObject.longRunMethod("222", 333);
42  
43      final int numberOfThreads = 10;
44      Stack<Thread> concurrentThreadsStack = new Stack<>();
45      Stack<SomeTestClassWithLongRunningMethod> objectsUnderTestStack = new Stack<>();
46  
47  
48      final CountDownLatch latch = new CountDownLatch(numberOfThreads);
49  
50  
51      IntStream.range(0, numberOfThreads).forEach(i -> {
52        final SomeTestClassWithLongRunningMethod someObject = new SomeTestClassWithLongRunningMethod();
53        final Task<Long> methodCall = () -> someObject.longRunMethod("222", 333);
54  
55        final TaskWithFixedHashcode<Long> taskWithFixedHashcode = new TaskWithFixedHashcode(methodCall); // make sure that fixed hashcode does not all tasks to
56  
57        // override each other when CircuitBreaker is running
58  
59  
60        Thread thread = new Thread(() -> {
61          try {
62            circuitBreaker.execute(taskWithFixedHashcode, breakStrategy, breakHandler);
63          } catch (TaskExecutionException e) {
64            throw new RuntimeException("Exception while running method on thread no: " + i, e);
65          }
66        });
67  
68        concurrentThreadsStack.push(thread);
69        objectsUnderTestStack.push(someObject);
70  
71      });
72  
73      concurrentThreadsStack.forEach(Thread::start);
74      concurrentThreadsStack.forEach(t -> waitUntilAllTasksAreInProgress(tasksInProgress, numberOfThreads));
75  
76  
77      assertThat(tasksInProgress.size(), equalTo(numberOfThreads));
78  
79      Thread annihilator = new Thread(() -> {
80        int numberOfTasksInProgress = tasksInProgress.size();
81  
82        while (!objectsUnderTestStack.empty()) {
83          try {
84            final SomeTestClassWithLongRunningMethod currentObject = objectsUnderTestStack.pop();
85            final Thread threadOfCurrentObject = concurrentThreadsStack.pop();
86  
87            currentObject.annihilate();
88            numberOfTasksInProgress--;
89  
90            waitUntilThreadIsDead(threadOfCurrentObject);
91  
92            assertThat(tasksInProgress.size(), equalTo(numberOfTasksInProgress));
93  
94          } finally {
95            latch.countDown();
96  
97          }
98        }
99  
100     });
101 
102     annihilator.start();
103 
104 
105     latch.await(10, TimeUnit.MINUTES);
106   }
107 
108   private void waitUntilAllTasksAreInProgress(Map<String, ExecutedTask> tasksInProgress, int expectedNumberOfTasks) {
109     final int maxLoopIterations = 200;
110     int currentLoopIteration = 0;
111     final int sleepTimeMilis = 10;
112 
113 
114     while (tasksInProgress.size() != expectedNumberOfTasks) {
115       try {
116         Thread.sleep(sleepTimeMilis);
117       } catch (InterruptedException e) {
118         // ok to continue
119       }
120 
121       currentLoopIteration++;
122 
123       if (currentLoopIteration == maxLoopIterations) {
124         throw new RuntimeException(
125           "After: " + (sleepTimeMilis * maxLoopIterations) + " miliseconds number of tasks in progress within " +
126           CircuitBreaker.class.getSimpleName() + " was: " +
127           tasksInProgress.size() + ", but expected " + expectedNumberOfTasks +
128           ". It may indicate a problem in calculating Key for internal " +
129           "map storing tasks in progress.");
130       }
131 
132     }
133   }
134 
135   private void waitUntilThreadIsDead(Thread thread) {
136     while (thread.isAlive()) {
137       try {
138         Thread.sleep(10);
139       } catch (InterruptedException e) {
140         // ok to continue
141       }
142     }
143   }
144 
145   class SomeTestClassWithLongRunningMethod extends Thread {
146     private boolean shouldRun = true;
147     private boolean started = false;
148 
149     SomeTestClassWithLongRunningMethod() {
150     }
151 
152     public Long longRunMethod(String paramValue1, Integer paramValue2) {
153       started = true;
154 
155       while (shouldRun) {
156         try {
157           Thread.sleep(100);
158         } catch (InterruptedException e) {
159           // interrupted
160         }
161       }
162 
163       return (long) 1;
164     }
165 
166     public void annihilate() {
167       shouldRun = false;
168     }
169 
170     public boolean isStarted() {
171       return started;
172     }
173   }
174 
175   class TaskWithFixedHashcode<R> implements Task<R> {
176     private final Task<R> methodCall;
177 
178     TaskWithFixedHashcode(Task<R> c) {
179       this.methodCall = c;
180     }
181 
182 
183     @Override
184     public R execute() throws Exception {
185       return methodCall.execute();
186     }
187 
188     @Override
189     public boolean equals(Object o) {
190       if (this == o) {
191         return true;
192       }
193       if ((o == null) || (getClass() != o.getClass())) {
194         return false;
195       }
196 
197       TaskWithFixedHashcode<?> that = (TaskWithFixedHashcode<?>) o;
198       return Objects.equals(methodCall, that.methodCall);
199     }
200 
201     @Override
202     public int hashCode() {
203       return 5;
204     }
205   }
206 
207 }