Python-Java IPC, Redis or Plain Socket?



The performance of multithreaded Python programs is limited by the GIL (Global Interpreter Lock), which lets only one thread execute Python bytecode at a time. One way around this is inter-process communication (IPC): hand the work to another process. For example, we can set up a socket connection between Python and Java and use Java's multithreading to boost performance. Another way to handle communication between two processes is the pub/sub feature provided by Redis. In this post, we compare both approaches.

Testing Process

On the Python side, the code generates messages of different sizes and sends them to the Java program, which simply echoes each message back. We measure the round-trip time of every message.

We generate 100 messages for each size, from 1 KB to 3000 KB (1 KB = 1000 bytes here), starting with the largest size.
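
Both test programs below follow the same send, wait for the echo, verify, record loop. As a sketch of that procedure, assuming any transport can be wrapped in a hypothetical round_trip(payload) -> reply callable, the loop and the median/mean summary reported in the results can be factored out as follows. The function name measure_round_trips and the in-process echo are illustrative only; they stand in for the Redis and socket round trips.

```python
import statistics
import time
from typing import Callable, Dict, List, Tuple

KB = 1000  # decimal kilobytes, as in the test scripts


def measure_round_trips(
    round_trip: Callable[[bytes], bytes],
    sizes: List[int],
    count: int = 100,
    pause_s: float = 0.0,
) -> Dict[int, Tuple[float, float]]:
    """Return {size: (median_ms, mean_ms)} of round-trip latency for each message size.

    round_trip is a hypothetical callable that sends a payload and returns the reply.
    """
    summary: Dict[int, Tuple[float, float]] = {}
    for size in sizes:
        message = b"x" * size
        latencies_ms = []
        for _ in range(count):
            start = time.perf_counter_ns()
            reply = round_trip(message)
            latencies_ms.append((time.perf_counter_ns() - start) / 1e6)
            # The peer is expected to echo the payload unchanged.
            if reply != message:
                raise ValueError(f"echo mismatch for a {size}-byte message")
            # The test scripts pause 0.1 s between requests; pass pause_s=0.1 to match.
            time.sleep(pause_s)
        summary[size] = (statistics.median(latencies_ms), statistics.mean(latencies_ms))
    return summary


if __name__ == "__main__":
    # In-process echo as a placeholder; a Redis or socket round trip would go here instead.
    def local_echo(payload: bytes) -> bytes:
        return bytes(payload)

    sizes = [KB, 10 * KB, 50 * KB, 100 * KB, 500 * KB, 1000 * KB, 2000 * KB, 2500 * KB, 3000 * KB]
    for size, (median_ms, mean_ms) in measure_round_trips(local_echo, sizes, count=10).items():
        print(f"{size // KB:>5} KB  median {median_ms:.3f} ms  mean {mean_ms:.3f} ms")
```

Keeping the transport behind a single callable means the Redis and socket variants are timed by exactly the same code, so any difference in the numbers comes from the transport itself.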

Use Pub/Sub in Redis

Python code:

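import time

import pandas as pd
import redis

CHANNEL_REQUEST = "request"
CHANNEL_RESPONSE = "response"

# The Redis server in this test listens on port 7379 (the Redis default is 6379).
redisClient = redis.Redis(host="localhost", port=7379)
# Skip the subscribe confirmation so that get_message only returns echoed payloads.
pubsub = redisClient.pubsub(ignore_subscribe_messages=True)
pubsub.subscribe(CHANNEL_RESPONSE)

kb = 1000
# Largest messages first: 3000 KB down to 1 KB.
listOfMessageSize = [kb, 10 * kb, 50 * kb, 100 * kb, 500 * kb, 1000 * kb, 2000 * kb, 2500 * kb, 3000 * kb][::-1]
count = 100

records = []
for messageSize in listOfMessageSize:
    print("message size: {}".format(messageSize))
    message = "x" * messageSize
    for k in range(count):
        print("k = {}".format(k))

        start = time.perf_counter_ns()
        redisClient.publish(CHANNEL_REQUEST, message)
        m = None
        while m is None:
            # Poll for the echoed message, waiting up to 5 ms per call.
            m = pubsub.get_message(timeout=0.005)
        end = time.perf_counter_ns()

        assert m["data"].decode("utf-8") == message
        records.append([messageSize, end - start])
        time.sleep(0.1)

# Summarize round-trip latency in milliseconds per message size.
df = pd.DataFrame(records, columns=["messageSize", "latencyNs"])
print(df.groupby("messageSize")["latencyNs"].agg(["median", "mean"]) / 1e6)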

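The Java code is given below. Note that we need two Jedis clients: once a connection subscribes to a channel, it enters subscribed mode and can no longer issue commands such as PUBLISH. In addition, the subscribe call blocks the calling thread while it dispatches incoming messages.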

package language.socket;

import lombok.extern.slf4j.Slf4j;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPubSub;

@Slf4j
public class RedisClient {
    private static final String HOST = "127.0.0.1";
    private static final int PORT = 7379;
    // Jedis connection and socket timeout, in milliseconds.
    private static final int TIMEOUT_MS = 60;

    public static void main(String[] args) {
        log.info("Starting the program");

        // A connection in subscribed mode cannot PUBLISH, so replies go through a second client.
        // try-with-resources closes both connections when subscribe() returns or throws.
        try (Jedis subscriber = new Jedis(HOST, PORT, TIMEOUT_MS);
             Jedis publisher = new Jedis(HOST, PORT, TIMEOUT_MS)) {

            JedisPubSub jedisPubSub = new JedisPubSub() {
                @Override
                public void onMessage(String channel, String message) {
                    log.info("Received message from channel {}. size of message = {}", channel, message.length());
                    long start = System.nanoTime();
                    // Echo the request back on the response channel.
                    publisher.publish("response", message);
                    log.info("Redis,{},{}", System.nanoTime() - start, message.length());
                }
            };

            // Blocks the main thread and dispatches each incoming message to onMessage.
            subscriber.subscribe(jedisPubSub, "request");
        }
    }
}

Use Socket

Python code:

import logging
import socket
import sys
import time

import pandas as pd

logging.basicConfig(level=logging.INFO,
                    stream=sys.stdout,
                    format='%(asctime)s.%(msecs)03d [%(levelname)s] (%(name)s) %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

log = logging.getLogger("main")
log.setLevel(logging.DEBUG)


class Socket:
    """Newline-delimited request/response client over TCP."""

    BUFFER_SIZE = 5 * 1024 * 1024
    MESSAGE_END = b'\n'
    MESSAGE_END_INT_VALUE = MESSAGE_END[0]  # indexing bytes yields an int (10)

    def __init__(self):
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.settimeout(0.1)
        # Reusable receive buffer; slice assignment grows the bytearray if a message is larger.
        self.buffer = bytearray(Socket.BUFFER_SIZE)

    def connect(self, host, port):
        self.sock.connect((host, port))

    def _send(self, msg):
        if isinstance(msg, str):
            msg = msg.encode("utf-8")

        view = memoryview(msg)  # slicing a memoryview does not copy the payload
        totalsent = 0
        while totalsent < len(view):
            stopIndex = min(len(view), totalsent + Socket.BUFFER_SIZE)
            sent = self.sock.send(view[totalsent:stopIndex])
            if sent == 0:
                raise RuntimeError("socket connection broken")
            totalsent += sent

    def send(self, msg):
        self._send(msg)
        self._send(Socket.MESSAGE_END)

    def receive(self):
        """
        Read until the newline terminator arrives.

        Assumes exactly one response per request and no newline inside the payload,
        so a chunk that ends with the terminator marks the end of the message.

        :return: Response bytes without the trailing newline
        """
        numOfBytesReceived = 0

        while True:
            try:
                chunk = self.sock.recv(4096)
            except socket.timeout as e:
                # Nothing arrived within 0.1 s; keep waiting for the rest of the response.
                log.warning(e)
                continue

            if chunk == b'':
                raise RuntimeError("socket connection broken")

            self.buffer[numOfBytesReceived:numOfBytesReceived + len(chunk)] = chunk
            numOfBytesReceived += len(chunk)

            if chunk[-1] == Socket.MESSAGE_END_INT_VALUE:
                break

        return bytes(self.buffer[:numOfBytesReceived - 1])


def test(mySocket1):
    kb = 1000
    # Largest messages first: 3000 KB down to 1 KB.
    listOfMessageSize = [kb, 10 * kb, 50 * kb, 100 * kb, 500 * kb, 1000 * kb, 2000 * kb, 2500 * kb, 3000 * kb][::-1]
    count = 100

    records = []
    for messageSize in listOfMessageSize:
        print("message size: {}".format(messageSize))
        message = "x" * messageSize
        for k in range(count):
            print("k = {}".format(k))

            start = time.perf_counter_ns()
            mySocket1.send(message)
            m = mySocket1.receive()
            end = time.perf_counter_ns()

            assert m.decode("utf-8") == message
            records.append([messageSize, end - start, "s-total"])

            time.sleep(0.1)

    return records


if __name__ == "__main__":
    mySocket1 = Socket()
    mySocket1.connect(host="127.0.0.1", port=4000)
    records = test(mySocket1)

    # Summarize round-trip latency in milliseconds per message size.
    df = pd.DataFrame(records, columns=["messageSize", "latencyNs", "kind"])
    print(df.groupby("messageSize")["latencyNs"].agg(["median", "mean"]) / 1e6)
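
The newline framing shared by Socket.receive and the Java server can be exercised without either real endpoint. The following is a minimal sketch, assuming a thread on the other end of socket.socketpair() plays the Java server's role; echo_peer and read_message are hypothetical names, and like the client above the sketch assumes one reply per request and no newline inside the payload.

```python
import socket
import threading


def echo_peer(conn: socket.socket) -> None:
    """Stand-in for the Java server: echo every newline-terminated message."""
    pending = bytearray()
    with conn:
        while True:
            chunk = conn.recv(4096)
            if not chunk:  # the client closed the connection
                return
            pending += chunk
            if b"\n" not in chunk:
                continue  # no terminator yet; keep buffering
            end = pending.rindex(b"\n") + 1
            conn.sendall(bytes(pending[:end]))  # echo all complete messages, terminators included
            del pending[:end]


def read_message(sock: socket.socket) -> bytes:
    """Read one newline-terminated reply (one reply per request) and strip the newline."""
    buffer = bytearray()
    while not buffer.endswith(b"\n"):
        chunk = sock.recv(4096)
        if not chunk:
            raise RuntimeError("socket connection broken")
        buffer += chunk
    return bytes(buffer[:-1])


if __name__ == "__main__":
    client, server = socket.socketpair()
    threading.Thread(target=echo_peer, args=(server,), daemon=True).start()
    with client:
        for size in (1000, 100 * 1000, 3000 * 1000):
            message = b"x" * size
            client.sendall(message + b"\n")
            assert read_message(client) == message
            print(f"{size} bytes echoed")
```

Because the peer only starts writing after it has seen the terminator, the client can finish sending a large message before it starts reading, so the exchange does not deadlock even for 3000 KB payloads.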

Java code:

package language.socket;

import lombok.extern.slf4j.Slf4j;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

@Slf4j
public class Server {
    private static final int PORT_NUMBER = 4000;

    private static void handleSocketConnection(Socket clientSocket) {
        String clientSocketInfo = clientSocket.getInetAddress() + ":" + clientSocket.getPort();
        log.info("Start a new connection {}", clientSocketInfo);

        final int internalReceiveSize = 4096;
        final int step = 8192;

        try (InputStream inputStream = clientSocket.getInputStream();
             OutputStream os = clientSocket.getOutputStream()) {
            clientSocket.setTcpNoDelay(true);
            clientSocket.setSendBufferSize(2310720);
            clientSocket.setReceiveBufferSize(1024 * 1024);
            // Log the effective sizes; the OS may adjust the requested values.
            log.info("Send Buffer size: {}", clientSocket.getSendBufferSize());
            log.info("Recv Buffer size: {}", clientSocket.getReceiveBufferSize());

            // Each message, including its '\n' terminator, must fit in this 5 MiB buffer.
            byte[] buffer = new byte[5 * 1024 * 1024];

            while (true) {
                long t0 = System.nanoTime();

                // Read until the last byte received is the '\n' terminator.
                int numOfBytesReceived = 0;
                do {
                    int maxLen = Math.min(internalReceiveSize, buffer.length - numOfBytesReceived);
                    if (maxLen == 0) {
                        log.error("Message from {} does not fit in the receive buffer", clientSocketInfo);
                        return;
                    }
                    int n = inputStream.read(buffer, numOfBytesReceived, maxLen);
                    if (n == -1) {
                        return; // the client closed the connection
                    }
                    numOfBytesReceived += n;
                } while (buffer[numOfBytesReceived - 1] != '\n');

                String inputLine = new String(buffer, 0, numOfBytesReceived - 1, StandardCharsets.UTF_8);
                long t1 = System.nanoTime();

                log.info("Received message of size {} from {}", inputLine.length(), clientSocketInfo);

                // Echo the payload in 8 KB writes, followed by the terminator.
                final byte[] bytes = inputLine.getBytes(StandardCharsets.UTF_8);
                for (int k = 0; k < bytes.length; k += step) {
                    os.write(bytes, k, Math.min(step, bytes.length - k));
                }
                os.write('\n');
                os.flush();
                long t2 = System.nanoTime();

                log.info("socket,total,{},{}", t2 - t0, inputLine.length());
                log.info("socket,recv,{},{}", t1 - t0, inputLine.length());
                log.info("socket,send,{},{}", t2 - t1, inputLine.length());
            }
        } catch (IOException e) {
            log.error("Socket connection error on {}", clientSocketInfo, e);
        } finally {
            try {
                log.info("Close socket connection {}", clientSocketInfo);
                clientSocket.close();
            } catch (IOException e) {
                log.error("IOException when closing socket.", e);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        log.info("Starting program.");
        ExecutorService executorService = Executors.newFixedThreadPool(20);

        try (ServerSocket serverSocket = new ServerSocket(PORT_NUMBER)) {
            while (true) {
                try {
                    // accept() blocks until a client connects; each connection is served by a pool thread.
                    Socket clientSocket = serverSocket.accept();
                    executorService.submit(() -> handleSocketConnection(clientSocket));
                } catch (IOException e) {
                    log.error("Failed to accept a connection.", e);
                }
            }
        }
    }
}

Result

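The detailed data is attached at the end of this post. Using a socket is faster than Redis at every message size, most likely because a socket carries each message directly between the two processes, whereas with pub/sub every request and every response passes through the Redis server, adding an extra hop in each direction.

All things considered, Redis still performs well: a 1000 KB message takes about 7 ms for a round trip through Redis (6.65 ms median, 7.23 ms average).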

Median round-trip latency in milliseconds:

Message size   Redis   Socket
1 KB           0.83    0.4
10 KB          0.86    0.48
50 KB          1.3     0.79
100 KB         1.72    1.19
500 KB         4.77    4.46
1000 KB        6.65    6.26
2000 KB        11.94   10.27
2500 KB        17.93   14.77
3000 KB        27.15   20.8

Average round-trip latency in milliseconds:

Message size   Redis   Socket
1 KB           0.88    0.45
10 KB          0.91    0.58
50 KB          1.32    0.95
100 KB         1.94    1.5
500 KB         5.45    4.72
1000 KB        7.23    6.22
2000 KB        13.29   10.93
2500 KB        19.77   16.83
3000 KB        31.61   26.87
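
To compare the two transports size by size, the ratio of Redis latency to socket latency can be computed directly from the tables above. A small sketch; the constant names and the slowdown helper are illustrative, and the numbers are copied verbatim from the median and average tables.

```python
# Round-trip latencies in milliseconds, copied from the median and average tables.
SIZES_KB = [1, 10, 50, 100, 500, 1000, 2000, 2500, 3000]
MEDIAN_MS = {
    "redis": [0.83, 0.86, 1.3, 1.72, 4.77, 6.65, 11.94, 17.93, 27.15],
    "socket": [0.4, 0.48, 0.79, 1.19, 4.46, 6.26, 10.27, 14.77, 20.8],
}
AVERAGE_MS = {
    "redis": [0.88, 0.91, 1.32, 1.94, 5.45, 7.23, 13.29, 19.77, 31.61],
    "socket": [0.45, 0.58, 0.95, 1.5, 4.72, 6.22, 10.93, 16.83, 26.87],
}


def slowdown(table):
    """Ratio of Redis latency to socket latency for each message size."""
    return [r / s for r, s in zip(table["redis"], table["socket"])]


if __name__ == "__main__":
    for size, med, avg in zip(SIZES_KB, slowdown(MEDIAN_MS), slowdown(AVERAGE_MS)):
        print(f"{size:>5} KB  median x{med:.2f}  average x{avg:.2f}")
```

On the medians, Redis takes about twice as long as the socket for 1 KB messages (0.83 ms vs 0.4 ms), only about 6 to 7 percent longer at 500 KB and 1000 KB, and about 30 percent longer at 3000 KB.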

Histograms:

[Figures: round-trip latency histograms for 1 KB, 10 KB, 50 KB, 100 KB, 500 KB, 1000 KB, 2000 KB, 2500 KB and 3000 KB messages]
