上文我们学习了ThreadLocal的基本用法以及基本原理,ThreadLocal中的方法并不多,基本用到的也就get、set、remove等方法,但是其核心逻辑还是在定义在ThreadLocal内部的静态内部类ThreadLocalMap中,里面有很多设计非常精妙的地方,本文中我们就从ThreadLocalMap的角度入手深入学习ThreadLocal的原理。
按照官方的解释是:这是一个定制化的Hash类型的map,专门用来保存线程本地变量。其内部采用是通过一个自定义的Entry来封装数据,并且保存在一个Entry数组中。为了便于处理大量且长时间存活的对象引用(其实是ThreadLocal),Entry采用WeakReference作为key的类型,当map中空间不够时,key为null的ertry将会被删除。ThreadLocalMap内部数据结构如下:
static class ThreadLocalMap { static class Entry extends WeakReference<ThreadLocal<?>> { /** 要保存到线程本地的变量 */ Object value; Entry(ThreadLocal<?> k, Object v) { super(k); value = v; } } /** * 数组初始容量 -- 必须为2的倍数. */ private static final int INITIAL_CAPACITY = 16; /** * 存储entry的数组,长度为2的倍数 */ private Entry[] table; /** * entries数量 */ private int size = 0; /** * resize阈值 */ private int threshold; // Default to 0 /** * 计算阈值 */ private void setThreshold(int len) { threshold = len * 2 / 3; } /** * i+1,大于等于len则从0开始继续 */ private static int nextIndex(int i, int len) { return ((i + 1 < len) ? i + 1 : 0); } /** * i-1,小于0则从len-1开始继续 */ private static int prevIndex(int i, int len) { return ((i - 1 >= 0) ? i - 1 : len - 1); } ...... }
在ThreadLocalMap内部通过自定义的Entry类来封装要保存的数据,以ThreadLocal类型对象为key,Object类型对象为value。这个Entry继承自WeakReference<ThreadLocal<?>>,每个Entry都可以是一个指向ThreadLocal对象的弱引用,可通过Entry的get方法来获取对ThreadLocal对象的引用,而这个引用就是key。所有的Entry统一保存在一个Entry数组table中,数组的长度必须为2的倍数,通过key的hashcode与数组长度减1进行与运算来定位Entry在数组中的存储位置,这点和hashmap类似,但是当发生hash碰撞时hashmap的处理方法是放入链表或者树中(都在同一个hash桶中),而ThreadLocalMap则是依次往后查找可以保存的地方,没有桶的概念(这点后面会结合代码详细讲)。
既然ThreadLocalMap内部是一个数组,通过key的hashcode来定位到数组下标,这里我们不得不说一下key的hashcode的生成方式,非常精妙,因为key类型为ThreadLocal,所以其hashcode的生成方式也在ThreadLocal中:
private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647; private static int nextHashCode() { return nextHashCode.getAndAdd(HASH_INCREMENT); }
对于每个ThreadLocal对象,都有一个独自不变的hashcode,每新增一个ThreadLocal对象,会自动生成其自己的hashcode,其实就是让nextHashCode自增0x61c88647,目的是为了让生成的hashcode均匀的分布在2的幂次方上,而数组长度也是2的幂次方,这样就保证了要插入的元素可以均匀分布在数组中。
虽然ThreadLocal使用了很牛逼的办法来生成hashcode,但是还是不可避免会产生hash碰撞,当出现碰撞时是如何来处理呢?我们接着看:
我们知道ThreadLocalMap是以Entry为基本单元保存数据的,而且是以key-value对的形式,我们先来看一下是如何通过key获取到Entry的:
private Entry getEntry(ThreadLocal<?> key) { int i = key.threadLocalHashCode & (table.length - 1); Entry e = table[i]; if (e != null && e.get() == key) return e; else return getEntryAfterMiss(key, i, e); }
这个逻辑比较简单:
/**
* 有三种情况下会执行这个方法
* 1. e为null;
* 2. e!=null,e的key=null;
* 3. e!=null,e的key!=null,e的key!=要找的key,即出现hash碰撞
**/
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { Entry[] tab = table; int len = tab.length; while (e != null) { ThreadLocal<?> k = e.get(); if (k == key) return e; if (k == null) expungeStaleEntry(i); else i = nextIndex(i, len); // 出现碰撞,则依次往后找 e = tab[i]; } return null; }
这里的逻辑也比较清晰:
对于key为空的Entry在ThreadLocal里面称为staleSlot,接下来看一下expungeStaleEntry:
private int expungeStaleEntry(int staleSlot) { Entry[] tab = table; int len = tab.length; // 直接将下标为staleSlot处的元素擦除,value和Entry都要擦除 tab[staleSlot].value = null; tab[staleSlot] = null; size--; // Rehash操作直到数组对应下标处元素为空的情况 Entry e; int i; for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; tab[i] = null; size--; } else { int h = k.threadLocalHashCode & (len - 1); if (h != i) { tab[i] = null; while (tab[h] != null) h = nextIndex(h, len); tab[h] = e; } } } return i; }
逻辑会稍微复杂一些,我们还是一步一步看:
expungeStaleEntry的作用是清除传入的staleSlot处的Entry,除此之外还会管两件"闲事":
从上面的分析我们得出,通过key获取元素时,如果从计算出来的下标能获取到符合要求的值则直接返回,否则会从该位置开始依次往后找;遇到Entry不为空但是Entry的key为空的会擦除该Entry并继续循环;遇到Entry不为空且key不为空(hash碰撞)则直接往后找;在整个找的过程中遇到Entry为null则停止查找,直接返回null。
接下来我们看看设置元素,也就是set方法:
private void set(ThreadLocal<?> key, Object value) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { ThreadLocal<?> k = e.get(); // 找到则直接替换,然后直接返回 if (k == key) { e.value = value; return; } // 发现staleSlot,则执行replaceStaleEntry,然后直接返回 if (k == null) { replaceStaleEntry(key, value, i); return; } } // 如果没有找到,则new一个Entry插入数组中 tab[i] = new Entry(key, value); int sz = ++size;
// 插入新的Etry之后需要试探的去擦除一些过期的slot(key=null的Entry),如果Entry数量大于阈值,则执行扩容 if (!cleanSomeSlots(i, sz) && sz >= threshold) rehash(); }
这也是一个私有方法,这里看起来代码不多,但是里面涉及到的东西很多,逻辑也要比get方法复杂,但是没关系,我们层层递进,一一分解。
上面的步骤看完之后,我们来看看其中当key为空时需要执行的replaceStaleEntry的逻辑:
private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) { Entry[] tab = table; int len = tab.length; Entry e; // 现在staleSlot处对应的Entry其key=null,往前查找看是否能不能找到一个stale的Entry int slotToExpunge = staleSlot; for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len)) if (e.get() == null) slotToExpunge = i; // Find either the key or trailing null slot of run, whichever // occurs first for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) { ThreadLocal<?> k = e.get(); // 找到了直接替换,替换之后再尝试删除一些stale的Entry if (k == key) { e.value = value; tab[i] = tab[staleSlot]; tab[staleSlot] = e; // Start expunge at preceding stale entry if it exists if (slotToExpunge == staleSlot) slotToExpunge = i; cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); return; } // 如果i处对应的Entry是stale,并且前面往前没有找到stale的Entry,则将i标识为待擦除的slot if (k == null && slotToExpunge == staleSlot) slotToExpunge = i; } // 如果没有找到传入key对应的entry,则new一个新Entry放在传入staleSlot下标处,现在staleSlot处的Entry不再是stale(过期的)了 tab[staleSlot].value = null; tab[staleSlot] = new Entry(key, value); // 如果还发现有其他stale entries存在, 将其清除 if (slotToExpunge != staleSlot) cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); }
这个replaceStaleEntry的逻辑比较难理解,只要清楚它主要干了下面两件事:
以及插入一个新的Entry之后,试探性地去删除多余的staleSlot(注意,是试探性的哦),逻辑在cleanSomeSlots中:
/**
* @param i 扫描起始下标,从第i+1处开始扫描
*
* @param n 扫描次数控制量,在往后面扫描的过程中,如果没有发现staleSlot,则最多扫描log2(n)个元素,否则在staleSlot之后再扫log2(table.length-1)个
**/
private boolean cleanSomeSlots(int i, int n) {
// 标识是否有删除过staleSlot
boolean removed = false; Entry[] tab = table; int len = tab.length; do { i = nextIndex(i, len); Entry e = tab[i]; if (e != null && e.get() == null) { n = len; removed = true; i = expungeStaleEntry(i); } } while ( (n >>>= 1) != 0); return removed; }
从i+1处开始,往后扫描,如果遇到staleSlot,则执行expungeStaleEntry,往后扫描log2(n)次结束循环,n为传入的参数,如果发现staleSlot,则将n更新为Entry数组长度len。
这个设计非常巧妙,试探性的扫描一些单元看是否能发现staleSlot(不新鲜的entrys,也就是key=null)。当一个新元素添加进来或者一个staleSlot被清除的时候,会调用这个方法。该方法扫描元素的数量是对数级的,如果不扫描就不能及时清除key为null的entry(会浪费内存),如果全数组扫描则会导致一次插入的时间复杂度为O(n),采用这种试探性的扫描方式其实是一种在功能和性能之间的平衡,尽最大努力清理垃圾,又不导致过于消耗性能。
如果插入了新Entry,且执行了cleanSomeSlots之后size的数量还是大于阈值的话,这时就需要rehash扩容了:
private void rehash() { expungeStaleEntries(); // Use lower threshold for doubling to avoid hysteresis if (size >= threshold - threshold / 4) resize(); }
// 扫描全表,清除所有staleSlot private void expungeStaleEntries() { Entry[] tab = table; int len = tab.length; for (int j = 0; j < len; j++) { Entry e = tab[j]; if (e != null && e.get() == null) expungeStaleEntry(j); } }
// 将表容量扩大一倍 private void resize() { Entry[] oldTab = table; int oldLen = oldTab.length; int newLen = oldLen * 2; Entry[] newTab = new Entry[newLen]; int count = 0; for (int j = 0; j < oldLen; ++j) { Entry e = oldTab[j]; if (e != null) { ThreadLocal<?> k = e.get(); if (k == null) { e.value = null; // Help the GC } else { int h = k.threadLocalHashCode & (newLen - 1); while (newTab[h] != null) h = nextIndex(h, newLen); newTab[h] = e; count++; } } } setThreshold(newLen); size = count; table = newTab; }
首先扫描全表,清除所有staleSlot,如果这还不能减小size,则将table容量扩大一倍。扩容的逻辑比较简单,根据新数组容量来计算新的数组下标,如果存在hash冲突就往后找,直到Entry为空则把元素放进去。
到这里我们学习了ThreadLocal的基本原理、核心数据结构、最常用的get和set方法,是不是对ThreadLocal有了更深入的了解呢?如果有,那非常高兴我的文章能给你带来一丁点价值^_^
前面有讲到,ThreadLocalMap中的Entry其类型是属于弱引用(继承了WeakReference),被弱引用指向的对象,在下一次GC时是会被回收的,除非这个对象还有强引用指向它(对Java中强、软、弱、虚引用不清楚的同学可以详细了解下),之所以这样设计,我的理解是Entry是存在ThreadLocalMap中,而这个map又是保存在线程thread中的,用户是不能直接获取到的,也是不能直接操作的,也就会影响到垃圾回收。为了避免因为ThreadLocalMap存储了ThreadLocal对象而影响到ThreadLocal对象的垃圾回收,JDK的设计者把主动权完全交给调用方,一旦调用方不想使用,只需设置ThreadLocal对象为null,内存就可以被回收掉了,这也是弱引用的一个主要使用场景。
另一方面,在set和getEntry的过程中会频繁的去清理stale entry,以及时释放空余位置,这样就可以及时清除value,因为value是我们要保存到ThreadLocal中的值,而这是强引用,即便是key被回收了,value依然不会被回收。
虽然ThreadLocal中做了种种设计来防止内存泄漏,但是如果使用不当还是会导致内存泄漏,我这里借用一个网上的例子,一起来感受下:
public class ThreadLocalLeakDemo { public static void main(String[] args) { new Thread(new Runnable() { @Override public void run() { for(int i = 0; i< 1000 ;i++) { TestClass t = new TestClass(i); t.printId();...
// 行1,注释掉这一行时不会导致内存溢出 t = null;
// 行2,注释掉这一行时会导致内存溢出 t.threadLocal.remove(); } } }).start();; } static class TestClass{ private int id; private int[] arr;
// 注意,这是一个普通成员哦 private ThreadLocal<TestClass> threadLocal; TestClass(int id){ this.id = id; arr = new int[1000000]; threadLocal = new ThreadLocal(); threadLocal.set(this); } public void printId() { System.out.println(threadLocal.get().id); } } }
/**
* 注释行2,放开行1时,会导致内存溢出,结果如下:
**/
/**
* 注释行1,放开行2时,不会导致内存泄漏,结果如下:
**/
...上面其实就是改了一行代码,就导致内存溢出,增加的那一步操作就是调用了ThreadLocal的remove,那我们就来看看remove的逻辑:
移除元素的逻辑很简单,根据传入的key定位到数组下标i,从这个下标开始往后循环,直到遇到Entry为空时停止循环。如果找到key对应的entry,则调用Entry的clear方法。
private void remove(ThreadLocal<?> key) { Entry[] tab = table; int len = tab.length; int i = key.threadLocalHashCode & (len-1); for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) { if (e.get() == key) { e.clear(); expungeStaleEntry(i); return; } } }
结合上面的例子和源码,我们解释一下为什么没有调用remove方法会导致内存溢出。如上,在不调用remove时,每一次循环都会插入一个新的Entry对象到ThreadLocalMap中,这个Entry是指向一个新的ThreadLocal对象,对于这个ThreadLocal对象存在两个引用:
由于强引用一直存在,而t=null并不能让value不可达,因为value是保存在线程本地内存中的,所以没法回收这个新的ThreadLocal对象,导致一直堆积,最终报OOM
而如果调用remove的话,则会直接将对应Entry以及其保存的value清空,这样就不会内存泄漏了。
其实上面的例子是使用不当导致的,如果将ThreadLocal成员变量置为static,也不会出现这个问题,因为即便有1000次循环,但是都是用的同一个ThreadLocal,在线程本地始终只有一份,用private static来修饰ThreadLocal也是一个官方推荐的惯用法。
以上为个人总结,如有不对,烦请指正。
原文:https://www.cnblogs.com/volcano-liu/p/10712524.html