1 package net.secodo.jcircuitbreaker.breaker.impl;
2
import net.secodo.jcircuitbreaker.breaker.CircuitBreaker;
import net.secodo.jcircuitbreaker.breaker.execution.ExecutedTask;
import net.secodo.jcircuitbreaker.breakhandler.BreakHandler;
import net.secodo.jcircuitbreaker.breakstrategy.BreakStrategy;
import net.secodo.jcircuitbreaker.exception.TaskExecutionException;
import net.secodo.jcircuitbreaker.task.Task;
import org.junit.Test;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Stack;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
21
22
23 @SuppressWarnings("unchecked")
24 public class DefaultCircuitBreakerConcurrencyTest {
25 @Test(timeout = 20000L)
26 public void shouldCorrectlyHandleNumberOfTasksInProgress() throws Exception {
27
28 final DefaultCircuitBreaker circuitBreaker = new DefaultCircuitBreaker();
29
30 final Field tasksInProgressField = AbstractCircuitBreaker.class.getDeclaredField("tasksInProgress");
31 tasksInProgressField.setAccessible(true);
32
33 final Map<String, ExecutedTask> tasksInProgress = (Map<String, ExecutedTask>) tasksInProgressField.get(
34 circuitBreaker);
35
36 final BreakStrategy breakStrategy = mock(BreakStrategy.class);
37 final BreakHandler<Long> breakHandler = mock(BreakHandler.class);
38
39
40
41
42
43 final int numberOfThreads = 10;
44 Stack<Thread> concurrentThreadsStack = new Stack<>();
45 Stack<SomeTestClassWithLongRunningMethod> objectsUnderTestStack = new Stack<>();
46
47
48 final CountDownLatch latch = new CountDownLatch(numberOfThreads);
49
50
51 IntStream.range(0, numberOfThreads).forEach(i -> {
52 final SomeTestClassWithLongRunningMethod someObject = new SomeTestClassWithLongRunningMethod();
53 final Task<Long> methodCall = () -> someObject.longRunMethod("222", 333);
54
55 final TaskWithFixedHashcode<Long> taskWithFixedHashcode = new TaskWithFixedHashcode(methodCall);
56
57
58
59
60 Thread thread = new Thread(() -> {
61 try {
62 circuitBreaker.execute(taskWithFixedHashcode, breakStrategy, breakHandler);
63 } catch (TaskExecutionException e) {
64 throw new RuntimeException("Exception while running method on thread no: " + i, e);
65 }
66 });
67
68 concurrentThreadsStack.push(thread);
69 objectsUnderTestStack.push(someObject);
70
71 });
72
73 concurrentThreadsStack.forEach(Thread::start);
74 concurrentThreadsStack.forEach(t -> waitUntilAllTasksAreInProgress(tasksInProgress, numberOfThreads));
75
76
77 assertThat(tasksInProgress.size(), equalTo(numberOfThreads));
78
79 Thread annihilator = new Thread(() -> {
80 int numberOfTasksInProgress = tasksInProgress.size();
81
82 while (!objectsUnderTestStack.empty()) {
83 try {
84 final SomeTestClassWithLongRunningMethod currentObject = objectsUnderTestStack.pop();
85 final Thread threadOfCurrentObject = concurrentThreadsStack.pop();
86
87 currentObject.annihilate();
88 numberOfTasksInProgress--;
89
90 waitUntilThreadIsDead(threadOfCurrentObject);
91
92 assertThat(tasksInProgress.size(), equalTo(numberOfTasksInProgress));
93
94 } finally {
95 latch.countDown();
96
97 }
98 }
99
100 });
101
102 annihilator.start();
103
104
105 latch.await(10, TimeUnit.MINUTES);
106 }
107
108 private void waitUntilAllTasksAreInProgress(Map<String, ExecutedTask> tasksInProgress, int expectedNumberOfTasks) {
109 final int maxLoopIterations = 200;
110 int currentLoopIteration = 0;
111 final int sleepTimeMilis = 10;
112
113
114 while (tasksInProgress.size() != expectedNumberOfTasks) {
115 try {
116 Thread.sleep(sleepTimeMilis);
117 } catch (InterruptedException e) {
118
119 }
120
121 currentLoopIteration++;
122
123 if (currentLoopIteration == maxLoopIterations) {
124 throw new RuntimeException(
125 "After: " + (sleepTimeMilis * maxLoopIterations) + " miliseconds number of tasks in progress within " +
126 CircuitBreaker.class.getSimpleName() + " was: " +
127 tasksInProgress.size() + ", but expected " + expectedNumberOfTasks +
128 ". It may indicate a problem in calculating Key for internal " +
129 "map storing tasks in progress.");
130 }
131
132 }
133 }
134
135 private void waitUntilThreadIsDead(Thread thread) {
136 while (thread.isAlive()) {
137 try {
138 Thread.sleep(10);
139 } catch (InterruptedException e) {
140
141 }
142 }
143 }
144
145 class SomeTestClassWithLongRunningMethod extends Thread {
146 private boolean shouldRun = true;
147 private boolean started = false;
148
149 SomeTestClassWithLongRunningMethod() {
150 }
151
152 public Long longRunMethod(String paramValue1, Integer paramValue2) {
153 started = true;
154
155 while (shouldRun) {
156 try {
157 Thread.sleep(100);
158 } catch (InterruptedException e) {
159
160 }
161 }
162
163 return (long) 1;
164 }
165
166 public void annihilate() {
167 shouldRun = false;
168 }
169
170 public boolean isStarted() {
171 return started;
172 }
173 }
174
175 class TaskWithFixedHashcode<R> implements Task<R> {
176 private final Task<R> methodCall;
177
178 TaskWithFixedHashcode(Task<R> c) {
179 this.methodCall = c;
180 }
181
182
183 @Override
184 public R execute() throws Exception {
185 return methodCall.execute();
186 }
187
188 @Override
189 public boolean equals(Object o) {
190 if (this == o) {
191 return true;
192 }
193 if ((o == null) || (getClass() != o.getClass())) {
194 return false;
195 }
196
197 TaskWithFixedHashcode<?> that = (TaskWithFixedHashcode<?>) o;
198 return Objects.equals(methodCall, that.methodCall);
199 }
200
201 @Override
202 public int hashCode() {
203 return 5;
204 }
205 }
206
207 }