前言
为了巩固对Spring框架的学习,我尝试手写一个简易版的Spring,实现SpringMVC的基本功能,借此学习Spring源码。
实现功能
这次先写了一个简易的框架,实现了最基本的IoC功能,以及SpringMVC中常用的注解,具体如下:
@Controller
@RequestMapping
@RequestParam
@Autowired
@Bean
代码
内嵌服务器
采用了Apache的tomcat-embed-core包,在项目中内置了一个Tomcat服务器,直接调用自己编写的startTomcat()方法即可启动。该包在Gradle中的依赖如下:
group: 'org.apache.tomcat.embed', name: 'tomcat-embed-core', version: '8.5.23'
下面是TomcatServer类:
/**
 * Starts an embedded Tomcat whose root context ("") routes every request ("/") to a single
 * {@link DispatcherServlet}.
 */
public class TomcatServer {

    /** Port used by {@link #TomcatServer(String[])}; this is the value that used to be hard-coded. */
    public static final int DEFAULT_PORT = 6699;

    private final String[] args;
    private final int port;
    private Tomcat tomcat;

    /**
     * Creates a server that will listen on {@link #DEFAULT_PORT}.
     *
     * @param args command-line arguments; stored but not read by this class
     */
    public TomcatServer(String[] args) {
        this(args, DEFAULT_PORT);
    }

    /**
     * Creates a server that will listen on the given port.
     *
     * @param args command-line arguments; stored but not read by this class
     * @param port port passed to {@code Tomcat#setPort}
     */
    public TomcatServer(String[] args, int port) {
        this.args = args;
        this.port = port;
    }

    /**
     * Builds the root context, maps the dispatcher servlet to "/", starts Tomcat and launches a
     * non-daemon thread that blocks in {@code Server#await()}.
     *
     * @throws LifecycleException if Tomcat fails to start
     */
    public void startTomcat() throws LifecycleException {
        tomcat = new Tomcat();
        tomcat.setPort(port);

        Context context = new StandardContext();
        context.setPath("");
        // Used because no web.xml is processed for this context (see Tomcat.FixContextListener docs).
        context.addLifecycleListener(new Tomcat.FixContextListener());

        DispatcherServlet servlet = new DispatcherServlet();
        Tomcat.addServlet(context, "dispatcherServlet", servlet).setAsyncSupported(true);
        context.addServletMappingDecoded("/", "dispatcherServlet");
        tomcat.getHost().addChild(context);

        // Start only after the whole component tree (host + context + servlet) is configured.
        tomcat.start();

        // Non-daemon, so the JVM keeps running while the server waits for a shutdown command.
        Thread awaitThread = new Thread(() -> tomcat.getServer().await(), "tomcat_await_thread");
        awaitThread.setDaemon(false);
        awaitThread.start();
    }
}
DispatcherServlet(负责请求分发)
DispatcherServlet是唯一注册到Tomcat的Servlet,它把每个请求依次交给已注册的MappingHandler处理,从而实现请求分发。
/**
 * The single servlet registered with Tomcat. Each request is offered to the handlers in
 * {@code HandlerManager.mappingHandlerList}, in list order, until one reports that it handled it.
 */
public class DispatcherServlet implements Servlet {

    /** Configuration handed over by the container in {@link #init(ServletConfig)}. */
    private ServletConfig servletConfig;

    @Override
    public void init(ServletConfig config) throws ServletException {
        // The Servlet contract requires getServletConfig() to return the object passed to init().
        this.servletConfig = config;
    }

    @Override
    public ServletConfig getServletConfig() {
        return servletConfig;
    }

    /**
     * Dispatches the request to the first handler whose URI matches.
     *
     * @throws ServletException if the matched controller method cannot be invoked or throws
     * @throws IOException if writing the response fails
     */
    @Override
    public void service(ServletRequest req, ServletResponse res) throws ServletException, IOException {
        for (MappingHandler mappingHandler : HandlerManager.mappingHandlerList) {
            boolean handled;
            try {
                handled = mappingHandler.handle(req, res);
            } catch (InvocationTargetException e) {
                // Surface the controller's own exception (not the reflection wrapper) and let the
                // container produce an error response instead of silently returning nothing.
                throw new ServletException("Controller method threw an exception", e.getCause());
            } catch (IllegalAccessException | InstantiationException e) {
                throw new ServletException("Could not invoke controller method", e);
            }
            if (handled) {
                return;
            }
        }
        // NOTE(review): no handler matched; nothing is written, so the client presumably receives an
        // empty 200. Consider replying 404 via HttpServletResponse#sendError.
    }

    @Override
    public String getServletInfo() {
        return null;
    }

    @Override
    public void destroy() {
    }
}
MappingHandler
MappingHandler用来处理经DispatcherServlet分发过来的请求,每个MappingHandler只处理与其URI匹配的请求。
/**
 * Binds one request URI to one controller method and invokes it with request parameters.
 */
public class MappingHandler {

    private final String uri;
    private final Method method;
    private final Class<?> controller;
    /** Request-parameter names, one per method parameter, in declaration order. */
    private final String[] args;

    /**
     * Handles the request if its URI equals this handler's URI exactly.
     *
     * <p>Each method argument is the (String or null) value of the corresponding request
     * parameter. The controller instance is taken from {@code BeanFactory}. A non-null return value
     * is written to the response followed by a line separator; a null result (including every
     * {@code void} method) writes nothing.
     *
     * @return {@code true} if the URI matched and the method was invoked, {@code false} otherwise
     */
    public boolean handle(ServletRequest request, ServletResponse response)
            throws IllegalAccessException, InstantiationException, InvocationTargetException, IOException {
        String requestUri = ((HttpServletRequest) request).getRequestURI();
        if (!uri.equals(requestUri)) {
            return false;
        }

        Object[] parameters = new Object[args.length];
        for (int i = 0; i < args.length; i++) {
            parameters[i] = request.getParameter(args[i]);
        }

        Object ctl = BeanFactory.getBean(controller);
        Object result = method.invoke(ctl, parameters);
        // Method.invoke returns null for void methods; calling toString() on it used to throw an NPE.
        if (result != null) {
            response.getWriter().println(result);
        }
        return true;
    }

    public MappingHandler(String uri, Method method, Class<?> controller, String[] args) {
        this.uri = uri;
        this.method = method;
        this.controller = controller;
        this.args = args;
    }
}
HandlerManager
HandlerManager负责解析带有@Controller注解的类,将其中带有@RequestMapping注解的方法转换为MappingHandler。
/**
 * Turns {@code @RequestMapping} methods of {@code @Controller} classes into {@link MappingHandler}s.
 */
public class HandlerManager {

    /** Registered handlers, in registration order; iterated by the dispatcher for each request. */
    public static List<MappingHandler> mappingHandlerList = new ArrayList<>();

    /**
     * Registers a handler for every {@code @RequestMapping} method of every {@code @Controller}
     * class in {@code classList}; other classes are ignored.
     */
    public static void resolveMappingHandler(List<Class<?>> classList) {
        for (Class<?> cls : classList) {
            if (cls.isAnnotationPresent(Controller.class)) {
                parseHandlerFromController(cls);
            }
        }
    }

    /** Adds one handler per {@code @RequestMapping} method declared directly on {@code cls}. */
    private static void parseHandlerFromController(Class<?> cls) {
        for (Method method : cls.getDeclaredMethods()) {
            RequestMapping mapping = method.getDeclaredAnnotation(RequestMapping.class);
            if (mapping == null) {
                continue;
            }
            // One name per parameter, so the handler always builds an argument array whose length
            // matches the method's arity (previously un-annotated parameters were dropped, and
            // Method.invoke then failed with "wrong number of arguments").
            Parameter[] parameters = method.getParameters();
            String[] params = new String[parameters.length];
            for (int i = 0; i < parameters.length; i++) {
                params[i] = requestParamName(parameters[i]);
            }
            mappingHandlerList.add(new MappingHandler(mapping.value(), method, cls, params));
        }
    }

    /**
     * Returns the request-parameter name bound to {@code parameter}: the {@code @RequestParam}
     * value if present, otherwise the reflective parameter name. That is the source name only when
     * compiled with {@code -parameters}; otherwise it is a synthetic "argN" that normally matches
     * no request parameter.
     */
    private static String requestParamName(Parameter parameter) {
        RequestParam requestParam = parameter.getDeclaredAnnotation(RequestParam.class);
        if (requestParam != null) {
            return requestParam.value();
        }
        return parameter.getName();
    }
}
ClassScanner
ClassScanner负责扫描指定包下的所有类,用于在初始化时获取类列表,进而完成依赖注入以及Controller的扫描。
/**
 * Finds the classes of a package in every class-path location that contains it, whether a jar
 * file or a plain class directory.
 */
public class ClassScanner {

    private static final String CLASS_SUFFIX = ".class";

    /**
     * Loads every class in {@code packageName} and its sub-packages.
     *
     * <p>Classes are loaded through the thread context class loader (falling back to this class's
     * loader) without running their static initializers.
     *
     * @param packageName dotted package name, e.g. {@code "com.example"}
     * @return the loaded classes; empty if no class-path location contains the package
     * @throws IOException if a class-path location cannot be read
     * @throws ClassNotFoundException if a discovered class file cannot be loaded
     */
    public static List<Class<?>> scanClass(String packageName) throws IOException, ClassNotFoundException {
        List<Class<?>> classList = new ArrayList<>();
        String path = packageName.replace(".", "/");
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        if (classLoader == null) {
            classLoader = ClassScanner.class.getClassLoader();
        }
        Enumeration<URL> resources = classLoader.getResources(path);
        while (resources.hasMoreElements()) {
            URL resource = resources.nextElement();
            String protocol = resource.getProtocol();
            if (protocol.contains("jar")) {
                JarURLConnection jarURLConnection = (JarURLConnection) resource.openConnection();
                String jarFilePath = jarURLConnection.getJarFile().getName();
                classList.addAll(getClassesFromJar(jarFilePath, path, classLoader));
            } else if ("file".equals(protocol)) {
                // Exploded class directories (IDE runs, build output) were previously ignored.
                collectClassesFromDirectory(toFile(resource), packageName, classLoader, classList);
            }
        }
        return classList;
    }

    /** Loads every class whose jar entry lies under {@code path}/ in the given jar file. */
    private static List<Class<?>> getClassesFromJar(String jarFilePath, String path, ClassLoader classLoader)
            throws IOException, ClassNotFoundException {
        List<Class<?>> classes = new ArrayList<>();
        // Require a trailing '/' so "com/foo" does not also match "com/foobar/...".
        String prefix = path.isEmpty() ? "" : path + "/";
        // try-with-resources: the JarFile opened here used to be leaked.
        try (JarFile jarFile = new JarFile(jarFilePath)) {
            Enumeration<JarEntry> jarEntries = jarFile.entries();
            while (jarEntries.hasMoreElements()) {
                String entryName = jarEntries.nextElement().getName();
                if (entryName.startsWith(prefix) && entryName.endsWith(CLASS_SUFFIX)) {
                    String classFullName = entryName
                            .substring(0, entryName.length() - CLASS_SUFFIX.length())
                            .replace('/', '.');
                    classes.add(Class.forName(classFullName, false, classLoader));
                }
            }
        }
        return classes;
    }

    /** Recursively loads every .class file below {@code dir}, naming classes relative to {@code packageName}. */
    private static void collectClassesFromDirectory(java.io.File dir, String packageName, ClassLoader classLoader,
            List<Class<?>> out) throws ClassNotFoundException {
        java.io.File[] children = dir.listFiles();
        if (children == null) {
            return; // not a directory, or not readable
        }
        for (java.io.File child : children) {
            String childName = child.getName();
            if (child.isDirectory()) {
                collectClassesFromDirectory(child, qualify(packageName, childName), classLoader, out);
            } else if (childName.endsWith(CLASS_SUFFIX)) {
                String simpleName = childName.substring(0, childName.length() - CLASS_SUFFIX.length());
                out.add(Class.forName(qualify(packageName, simpleName), false, classLoader));
            }
        }
    }

    /** Converts a {@code file:} URL to a File, decoding escapes such as {@code %20}. */
    private static java.io.File toFile(URL resource) throws IOException {
        try {
            return new java.io.File(resource.toURI());
        } catch (java.net.URISyntaxException e) {
            throw new IOException("Malformed class-path URL: " + resource, e);
        }
    }

    /** Joins a package name and a simple name, handling the default (empty) package. */
    private static String qualify(String packageName, String name) {
        return packageName.isEmpty() ? name : packageName + "." + name;
    }
}
Controller、RequestMapping和RequestParam注解
关于注解,可以参考我之前的文章:Java的注解
Controller
/**
 * Marks a class as a web controller. Its {@code @RequestMapping} methods become request handlers,
 * and the class itself is instantiated as a bean.
 */
@Target(value = ElementType.TYPE)
@Retention(value = RetentionPolicy.RUNTIME)
@Documented
public @interface Controller {
}
RequestMapping
/**
 * Maps a controller method to a request URI. The URI is compared verbatim (exact string equality)
 * with the request URI.
 */
@Target(value = ElementType.METHOD)
@Retention(value = RetentionPolicy.RUNTIME)
@Documented
public @interface RequestMapping {

    /** Request URI handled by the annotated method. */
    String value();
}
RequestParam
/**
 * Binds a controller-method parameter to the request parameter named by {@link #value()}.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER) // was misspelled "@Trget", which does not compile
public @interface RequestParam {

    /** Name of the request parameter whose value is passed to the annotated method parameter. */
    String value();
}
BeanFactory
这是IoC的核心部分,所有的依赖都由BeanFactory负责创建并注入。
| package top.makersy.beans;
import top.makersy.web.mvc.Controller;
import java.lang.reflect.Field; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap;
/**
 * Minimal IoC container. It instantiates {@code @Bean} and {@code @Controller} classes and fills
 * their {@code @Autowired} fields with the bean registered for exactly the field's type.
 */
public class BeanFactory {

    /** Created beans keyed by their concrete class. */
    private static Map<Class<?>, Object> classToBean = new ConcurrentHashMap<>();

    /** Returns the bean registered for exactly {@code cls}, or {@code null} if none exists. */
    public static Object getBean(Class<?> cls) {
        return classToBean.get(cls);
    }

    /**
     * Creates beans for all {@code @Bean}/{@code @Controller} classes in {@code classList}.
     * Passes over the remaining classes repeat until each one's {@code @Autowired} dependencies
     * have been created.
     *
     * @param classList candidate classes; classes without either annotation are skipped
     * @throws Exception if a whole pass creates nothing: a dependency cycle, or a dependency type
     *     that is never registered as a bean. Also thrown when reflective instantiation or
     *     injection fails.
     */
    public static void initBean(List<Class<?>> classList) throws Exception {
        List<Class<?>> toCreate = new ArrayList<>(classList);
        while (!toCreate.isEmpty()) {
            // Collect what is still pending into a fresh list. Removing by index inside a forward
            // loop skipped the element that followed each removal.
            List<Class<?>> pending = new ArrayList<>();
            for (Class<?> cls : toCreate) {
                if (!finishCreate(cls)) {
                    pending.add(cls);
                }
            }
            if (pending.size() == toCreate.size()) {
                throw new Exception("cycle dependency! unresolved: " + pending);
            }
            toCreate = pending;
        }
    }

    /**
     * Tries to create, inject and register the bean for {@code cls}.
     *
     * @return {@code true} if {@code cls} needs no bean or its bean was created; {@code false} if
     *     some {@code @Autowired} dependency has not been created yet
     */
    private static boolean finishCreate(Class<?> cls) throws ReflectiveOperationException {
        if (!cls.isAnnotationPresent(Bean.class) && !cls.isAnnotationPresent(Controller.class)) {
            return true;
        }
        Field[] fields = cls.getDeclaredFields();
        // Check every dependency before instantiating, so the constructor does not run again on
        // each retry pass for a bean that cannot be completed yet.
        for (Field field : fields) {
            if (field.isAnnotationPresent(Autowired.class) && getBean(field.getType()) == null) {
                return false;
            }
        }
        // Constructor.newInstance instead of the deprecated Class.newInstance, which could
        // rethrow undeclared checked exceptions from the constructor.
        Object bean = cls.getDeclaredConstructor().newInstance();
        for (Field field : fields) {
            if (field.isAnnotationPresent(Autowired.class)) {
                field.setAccessible(true);
                field.set(bean, getBean(field.getType()));
            }
        }
        classToBean.put(cls, bean);
        return true;
    }
}
Autowired、Bean注解
Autowired
/**
 * Marks a field to be injected by BeanFactory with the bean registered for exactly the field's
 * declared type.
 */
@Target(value = ElementType.FIELD)
@Retention(value = RetentionPolicy.RUNTIME)
@Documented
public @interface Autowired {
}
Bean
/**
 * Marks a class that BeanFactory should instantiate and register as a bean, so it can be
 * injected into {@code @Autowired} fields.
 */
@Target(value = ElementType.TYPE)
@Retention(value = RetentionPolicy.RUNTIME)
@Documented
public @interface Bean {
}
最后
因为是刚开始,所以只实现了一些基础功能,IoC的循环依赖处理、AOP等都还没有实现,后面我会在此基础上继续完善。