手撸一个ORM框架

ORM(Object Relational Mapping,对象-关系映射)框架采用元数据来描述对象-关系映射细节。元数据一般采用XML格式,存放在专门的对象-关系映射文件中;也可以像本文一样,直接用注解写在实体类上。

只要提供了持久化类与表的映射关系,ORM框架在运行时就能参照这些映射信息,把对象持久化到数据库中。常见的ORM框架有:Hibernate(及其.NET移植版NHibernate)、iBATIS、MyBatis、EclipseLink、JFinal等。

下面手写一个简单的ORM框架来实现CRUD功能,主要用到Java的反射和注解机制。大致步骤如下:

  • 数据库连接配置,包括配置文件和数据源的实现(本文示例直接通过 DriverManager 获取连接,并未实现连接池)
  • 实现注解功能,注解解析
  • 实体类与表字段映射
  • 实现CRUD操作及结果集处理

数据库连接

driver=com.mysql.jdbc.Driver
url=jdbc:mysql://localhost:3306/orm?useUnicode=true&characterEncoding=utf-8&useSSL=true
username=root
password=123456

数据源配置

import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Properties;

/**
 * @Author: Usher
 * @Description:
 * Holds JDBC connection settings (driver, url, username, password) and opens connections
 * with them through {@link DriverManager}.
 */
public class DBSource {
    private String driver;
    private String url;
    private String username;
    private String password;

    public DBSource() {
    }

    /**
     * Reads the {@code driver}, {@code url}, {@code username} and {@code password} keys from
     * {@code properties}; a missing key leaves the corresponding setting {@code null}.
     *
     * @param properties the connection settings
     */
    public DBSource(Properties properties) {
        this.driver = properties.getProperty("driver");
        this.url = properties.getProperty("url");
        this.username = properties.getProperty("username");
        this.password = properties.getProperty("password");
    }

    /**
     * Opens a new JDBC connection with the configured settings. The caller owns (and must
     * close) the returned connection.
     *
     * <p>The driver class is loaded explicitly only when one is configured. JDBC 4+ drivers
     * register themselves with {@link DriverManager} automatically, so an absent or blank
     * {@code driver} setting no longer fails with a {@link NullPointerException} from
     * {@code Class.forName(null)}.
     *
     * @return a newly opened connection
     * @throws ClassNotFoundException if a configured driver class is not on the classpath
     * @throws java.sql.SQLException  if no driver accepts the URL or the connection fails
     */
    public Connection openConnection() throws Exception {
        if (driver != null && !driver.trim().isEmpty()) {
            Class.forName(driver);
        }
        return DriverManager.getConnection(url, username, password);
    }

    public String getDriver() {
        return driver;
    }

    public void setDriver(String driver) {
        this.driver = driver;
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getPassword() {
        return password;
    }

    public void setPassword(String password) {
        this.password = password;
    }
}

注解功能

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @Author: Usher
 * @Description:
 * Field-level annotation that maps an entity field to a table column.
 * Retained at runtime so the ORM can read it through reflection.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Column {

    /** Column name; the default is the empty string. */
    String value() default "";

    /**
     * Nullability flag, {@code true} by default.
     * NOTE(review): not read by the ORM helpers shown alongside this annotation.
     */
    boolean isNull() default true;

    /** Marks the primary-key field; {@code false} by default. */
    boolean isId() default false;
}

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * @Author: Usher
 * @Description:
 * Type-level annotation naming the database table an entity class maps to.
 * Retained at runtime so it can be read through reflection.
 */
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Table {

    /** Table name; the default is the empty string. */
    String value() default "";
}

实体类

import com.usher.annotation.Column;
import com.usher.annotation.Table;

/**
 * @Author: Usher
 * @Description:
 * Entity mapped to table {@code sys_user}.
 * The field declaration order below is intentionally unchanged: the ORM helpers iterate
 * {@code getDeclaredFields()} to build column lists.
 */
@Table("sys_user")
public class User {
    // primary key, stored in column user_id
    @Column(value = "user_id", isId = true)
    private String id;
    // no @Column: the column name is derived from the field name
    private String username;
    // stored in column "name"
    @Column("name")
    private String nickname;
    private String password;
    private String phone;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getNickname() {
        return nickname;
    }

    public void setNickname(String nickname) {
        this.nickname = nickname;
    }

    public String getPassword() {
        return password;
    }

    public void setPassword(String password) {
        this.password = password;
    }

    public String getPhone() {
        return phone;
    }

    public void setPhone(String phone) {
        this.phone = phone;
    }

    /** Renders every field, each value wrapped in single quotes (null prints as {@code null}). */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("User{");
        text.append("id='").append(id).append('\'');
        text.append(", username='").append(username).append('\'');
        text.append(", nickname='").append(nickname).append('\'');
        text.append(", password='").append(password).append('\'');
        text.append(", phone='").append(phone).append('\'');
        return text.append('}').toString();
    }
}

数据库操作

package com.usher.dao;

import java.util.List;

/**
 * @Author: Usher
 * @Description:
 * Generic DAO contract with basic CRUD operations for entities of type {@code T}.
 *
 * @param <T> the entity type
 */
public interface UserDao<T> {

    /** Returns all entities. */
    List<T> findAll();

    /** Persists {@code obj}; the int result is presumably an affected-row count (TODO confirm). */
    int save(T obj);

    /** Updates {@code obj}; the int result is presumably an affected-row count (TODO confirm). */
    int update(T obj);

    /** Deletes {@code obj}; the int result is presumably an affected-row count (TODO confirm). */
    int delete(T obj);
}

package com.usher.dao.impl;

import com.usher.bean.User;
import com.usher.dao.UserDao;

import java.util.List;

/**
 * @Author: Usher
 * @Description:
 * Placeholder {@link UserDao} for {@link User}: none of the methods touches a database yet.
 */
public class UserDaoImpl implements UserDao<User> {

    /**
     * Returns all users. Not implemented yet: returns an empty, mutable list instead of
     * {@code null} so callers can iterate the result without a null check.
     */
    @Override
    public List<User> findAll() {
        return new java.util.ArrayList<User>();
    }

    /** Not implemented yet; always returns 0. */
    @Override
    public int save(User obj) {
        return 0;
    }

    /** Not implemented yet; always returns 0. */
    @Override
    public int update(User obj) {
        return 0;
    }

    /** Not implemented yet; always returns 0. */
    @Override
    public int delete(User obj) {
        return 0;
    }
}

注解解析

package com.usher.utils;

import com.usher.annotation.Column;
import com.usher.annotation.Table;

import java.lang.reflect.Field;

/**
 * @Author: Usher
 * @Description:
 * Annotation parsing for the ORM framework: resolves table names, column names and the
 * primary-key field from {@link Table} / {@link Column}.
 */
public class ORMAnnotationUtil {

    /**
     * Returns the table name mapped to {@code beanClass}.
     *
     * <p>Uses {@code @Table.value()} when the annotation is present and its value is non-empty.
     * Otherwise it falls back to the class's simple name, lower-cased with
     * {@link java.util.Locale#ROOT} so the result does not depend on the JVM default locale.
     * Previously, a bare {@code @Table} produced an empty table name and therefore broken SQL.
     *
     * @param beanClass the entity class
     * @return the table name
     */
    public static String getTableName(Class<?> beanClass) {
        // read the @Table annotation via reflection
        Table table = beanClass.getAnnotation(Table.class);
        if (table != null && !table.value().isEmpty()) {
            return table.value();
        }
        return beanClass.getSimpleName().toLowerCase(java.util.Locale.ROOT);
    }

    /**
     * Returns the column name mapped to {@code field}.
     *
     * <p>Uses {@code @Column.value()} when the annotation is present and its value is non-empty.
     * Otherwise it falls back to the field name, lower-cased with {@link java.util.Locale#ROOT}.
     * Previously, {@code @Column(isId = true)} without a value produced an empty column name.
     *
     * @param field the entity field
     * @return the column name
     */
    public static String getColumnName(Field field) {
        Column column = field.getAnnotation(Column.class);
        if (column != null && !column.value().isEmpty()) {
            return column.value();
        }
        return field.getName().toLowerCase(java.util.Locale.ROOT);
    }

    /**
     * Finds the primary-key field declared directly on {@code cls}.
     *
     * @param cls the entity class
     * @return the first declared field whose {@code @Column} has {@code isId = true},
     *         or {@code null} if there is none
     */
    public static Field findIdField(Class<?> cls) {
        for (Field f : cls.getDeclaredFields()) {
            if (isId(f)) {
                return f;
            }
        }
        return null;
    }

    /**
     * Tells whether {@code field} is marked as the primary key.
     *
     * @param field the entity field
     * @return {@code true} if the field carries {@code @Column(isId = true)}
     */
    public static boolean isId(Field field) {
        Column column = field.getAnnotation(Column.class);
        return column != null && column.isId();
    }
}

会话工厂与CRUD操作(DBSessionFactory)

package com.usher.orm;

import com.usher.bean.User;
import com.usher.utils.ORMAnnotationUtil;

import java.io.FileReader;
import java.io.IOException;
import java.lang.reflect.Field;
import java.sql.*;
import java.util.*;
import java.util.Date;

/**
 * @Author: Usher
 * @Description:
 * Session factory of the mini ORM. It loads the JDBC settings from the classpath resource
 * {@code resources/dbConfig.properties} and opens {@link DBSession}s, each wrapping one JDBC
 * {@link Connection}.
 */
public class DBSessionFactory {
    /** Classpath location of the connection settings (read via the system class loader). */
    private static final String CONFIG_RESOURCE = "resources/dbConfig.properties";

    // data source built from the loaded properties
    private DBSource dbSource;
    // connection properties loaded from CONFIG_RESOURCE
    private Properties properties;

    /**
     * Loads the connection settings and builds the data source.
     *
     * @throws IOException if the configuration resource is missing or cannot be read
     * @throws Exception   kept in the signature for backward compatibility
     */
    public DBSessionFactory() throws Exception {
        properties = new Properties();
        // try-with-resources closes the stream (the original leaked it). A null resource is
        // allowed and simply not closed, so the explicit check below still runs.
        try (java.io.InputStream in = ClassLoader.getSystemResourceAsStream(CONFIG_RESOURCE)) {
            if (in == null) {
                // previously surfaced as a NullPointerException from Properties.load(null)
                throw new IOException("Database configuration not found on classpath: " + CONFIG_RESOURCE);
            }
            properties.load(in);
        }
        dbSource = new DBSource(properties);
    }

    /**
     * Opens a new session backed by a fresh JDBC connection. The caller must close the session.
     */
    public DBSession openSession() throws Exception {
        return new DBSession(dbSource.openConnection());
    }

    /**
     * Performs reflection-based CRUD over a single JDBC connection, for entity classes annotated
     * with {@code @Table} / {@code @Column}.
     *
     * <p>Values are always bound as {@link PreparedStatement} parameters (never concatenated
     * into SQL). Every statement and result set is closed, even when an exception is thrown.
     * Static and synthetic fields are not mapped.
     */
    public static class DBSession implements AutoCloseable {
        private Connection connection; // the wrapped database connection

        public DBSession(Connection connection) {
            this.connection = connection;
        }

        /**
         * Selects every row of {@code tClass}'s table.
         *
         * @param tClass the entity class (needs an accessible no-arg constructor)
         * @param <T>    the entity type
         * @return one entity per row; empty if the table is empty
         */
        public <T> List<T> list(Class<T> tClass) throws IllegalAccessException, InstantiationException, SQLException {
            List<Field> fields = mappedFields(tClass);
            String sql = String.format("select %s from %s",
                    columnList(fields), ORMAnnotationUtil.getTableName(tClass));
            System.out.println("Statement SQL: " + sql);

            try (Statement statement = connection.createStatement();
                 ResultSet resultSet = statement.executeQuery(sql)) {
                return listResultHandler(tClass, resultSet);
            }
        }

        /** Maps every remaining row of {@code resultSet} to a new {@code tClass} instance. */
        private <T> List<T> listResultHandler(Class<T> tClass, ResultSet resultSet)
                throws SQLException, IllegalAccessException, InstantiationException {
            List<Field> fields = mappedFields(tClass);
            List<T> list = new ArrayList<>();
            while (resultSet.next()) {
                list.add(mapRow(tClass, fields, resultSet));
            }
            return list;
        }

        /** Maps the first row of {@code resultSet}, or returns {@code null} when there is none. */
        private <T> T oneResultHandler(Class<T> tClass, ResultSet resultSet)
                throws SQLException, IllegalAccessException, InstantiationException {
            if (!resultSet.next()) {
                return null;
            }
            return mapRow(tClass, mappedFields(tClass), resultSet);
        }

        /**
         * Inserts {@code object} as one row, binding every mapped field (including the id).
         *
         * @return the number of inserted rows
         */
        public int save(Object object) throws SQLException, IllegalAccessException {
            Class<?> cls = object.getClass();
            List<Field> fields = mappedFields(cls);
            String sql = String.format("insert into %s(%s) values(%s)",
                    ORMAnnotationUtil.getTableName(cls), columnList(fields), placeholders(fields.size()));
            System.out.println("Insert SQL: " + sql);

            try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
                int index = 1; // JDBC parameters are 1-based
                for (Field f : fields) {
                    bindParameter(preparedStatement, index++, f.getType(), f.get(object));
                }
                return preparedStatement.executeUpdate();
            }
        }

        /**
         * Updates every non-id column of the row whose primary key equals {@code object}'s id.
         *
         * @return the number of updated rows
         * @throws IllegalArgumentException if the class has no id field or no other mapped field
         */
        public int update(Object object) throws IllegalAccessException, SQLException {
            Class<?> cls = object.getClass();
            Field idField = requireIdField(cls);

            List<Field> updateFields = new ArrayList<>();
            StringBuilder assignments = new StringBuilder();
            for (Field f : mappedFields(cls)) {
                if (ORMAnnotationUtil.isId(f)) {
                    continue;
                }
                // separator goes before each item, so there is no trailing comma wherever the id sits
                if (assignments.length() > 0) {
                    assignments.append(',');
                }
                assignments.append(ORMAnnotationUtil.getColumnName(f)).append("=?");
                updateFields.add(f);
            }
            if (updateFields.isEmpty()) {
                throw new IllegalArgumentException(cls.getName() + " has no non-id columns to update");
            }

            String sql = String.format("update %s set %s where %s=?",
                    ORMAnnotationUtil.getTableName(cls), assignments,
                    ORMAnnotationUtil.getColumnName(idField));
            System.out.println("Update SQL:" + sql);

            try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
                int index = 1;
                for (Field f : updateFields) {
                    bindParameter(preparedStatement, index++, f.getType(), f.get(object));
                }
                // the id is bound as the last parameter instead of being spliced into the SQL text
                bindParameter(preparedStatement, index, idField.getType(), idField.get(object));
                return preparedStatement.executeUpdate();
            }
        }

        /**
         * Loads the entity whose primary key equals {@code id}.
         *
         * @return the entity, or {@code null} if no row matches
         * @throws IllegalArgumentException if {@code tClass} has no id field
         */
        public <T> T getById(Class<T> tClass, Object id) throws SQLException, InstantiationException, IllegalAccessException {
            Field idField = requireIdField(tClass);
            String sql = String.format("select %s from %s where %s=?",
                    columnList(mappedFields(tClass)),
                    ORMAnnotationUtil.getTableName(tClass),
                    ORMAnnotationUtil.getColumnName(idField));
            System.out.println("Find By Id SQL: " + sql);

            try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
                bindParameter(preparedStatement, 1, idField.getType(), id);
                try (ResultSet resultSet = preparedStatement.executeQuery()) {
                    return oneResultHandler(tClass, resultSet);
                }
            }
        }

        /**
         * Deletes the row of {@code cls}'s table whose primary key equals {@code objectId}.
         *
         * @return the number of deleted rows
         * @throws IllegalArgumentException if {@code cls} has no id field
         */
        public int delete(Class<?> cls, Object objectId) throws SQLException {
            Field idField = requireIdField(cls);
            String sql = "delete from " + ORMAnnotationUtil.getTableName(cls)
                    + " where " + ORMAnnotationUtil.getColumnName(idField) + "=?";

            try (PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
                bindParameter(preparedStatement, 1, idField.getType(), objectId);
                return preparedStatement.executeUpdate();
            }
        }

        /**
         * Closes the connection on a best-effort basis: failures are printed, not thrown.
         * Calling it again is a no-op.
         */
        @Override
        public void close() {
            if (connection != null) {
                try {
                    connection.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                } finally {
                    connection = null;
                }
            }
        }

        /** Returns the non-static, non-synthetic declared fields of {@code cls}, made accessible. */
        private static List<Field> mappedFields(Class<?> cls) {
            List<Field> fields = new ArrayList<>();
            for (Field f : cls.getDeclaredFields()) {
                if (java.lang.reflect.Modifier.isStatic(f.getModifiers()) || f.isSynthetic()) {
                    continue;
                }
                f.setAccessible(true);
                fields.add(f);
            }
            return fields;
        }

        /** Joins the mapped column names of {@code fields} with commas. */
        private static String columnList(List<Field> fields) {
            StringBuilder columns = new StringBuilder();
            for (Field f : fields) {
                if (columns.length() > 0) {
                    columns.append(',');
                }
                columns.append(ORMAnnotationUtil.getColumnName(f));
            }
            return columns.toString();
        }

        /** Returns {@code count} comma-separated {@code ?} placeholders. */
        private static String placeholders(int count) {
            StringBuilder params = new StringBuilder();
            for (int i = 0; i < count; i++) {
                if (i > 0) {
                    params.append(',');
                }
                params.append('?');
            }
            return params.toString();
        }

        /** Returns the accessible primary-key field of {@code cls}, or fails with a clear message. */
        private static Field requireIdField(Class<?> cls) {
            Field idField = ORMAnnotationUtil.findIdField(cls);
            if (idField == null) {
                throw new IllegalArgumentException(cls.getName() + " has no field annotated with @Column(isId = true)");
            }
            idField.setAccessible(true);
            return idField;
        }

        /** Instantiates {@code tClass} through its no-arg constructor (replaces deprecated Class.newInstance). */
        private static <T> T newInstance(Class<T> tClass) throws InstantiationException, IllegalAccessException {
            try {
                return tClass.getDeclaredConstructor().newInstance();
            } catch (NoSuchMethodException | java.lang.reflect.InvocationTargetException e) {
                InstantiationException wrapped = new InstantiationException("Cannot instantiate " + tClass.getName());
                wrapped.initCause(e);
                throw wrapped;
            }
        }

        /** Builds one entity from the current row of {@code resultSet}. */
        private static <T> T mapRow(Class<T> tClass, List<Field> fields, ResultSet resultSet)
                throws SQLException, IllegalAccessException, InstantiationException {
            T obj = newInstance(tClass);
            for (Field f : fields) {
                injectColumn(obj, f, resultSet);
            }
            return obj;
        }

        /**
         * Copies the column mapped to {@code field} from the current row into {@code target}.
         * SQL NULL becomes {@code null} for boxed/reference types and the JDBC default (0) for
         * primitives. Fields of unsupported types are left untouched.
         */
        private static void injectColumn(Object target, Field field, ResultSet resultSet)
                throws SQLException, IllegalAccessException {
            String column = ORMAnnotationUtil.getColumnName(field);
            Class<?> type = field.getType();
            Object value;
            if (type == String.class) {
                value = resultSet.getString(column);
            } else if (type == int.class || type == Integer.class) {
                value = resultSet.getInt(column);
            } else if (type == long.class || type == Long.class) {
                value = resultSet.getLong(column);
            } else if (type == double.class || type == Double.class) {
                value = resultSet.getDouble(column);
            } else if (type == float.class || type == Float.class) {
                value = resultSet.getFloat(column);
            } else if (type == Date.class) {
                value = resultSet.getDate(column);
            } else {
                return;
            }
            // wasNull() must be queried right after the getter; primitives cannot hold null
            if (!type.isPrimitive() && resultSet.wasNull()) {
                value = null;
            }
            field.set(target, value);
        }

        /**
         * Binds {@code value} at {@code index}. Null values are bound with a typed setNull
         * (the original bound null Strings as the text "null" and NPE'd on null dates).
         * java.util.Date values become java.sql.Date; everything else goes through setObject.
         */
        private static void bindParameter(PreparedStatement ps, int index, Class<?> declaredType, Object value)
                throws SQLException {
            if (value == null) {
                ps.setNull(index, sqlTypeOf(declaredType));
            } else if (value instanceof Date) {
                ps.setDate(index, new java.sql.Date(((Date) value).getTime()));
            } else {
                ps.setObject(index, value);
            }
        }

        /** Maps a field type to the java.sql.Types code used when binding SQL NULL. */
        private static int sqlTypeOf(Class<?> type) {
            if (type == String.class) {
                return Types.VARCHAR;
            } else if (type == Integer.class || type == int.class) {
                return Types.INTEGER;
            } else if (type == Long.class || type == long.class) {
                return Types.BIGINT;
            } else if (type == Double.class || type == double.class) {
                return Types.DOUBLE;
            } else if (type == Float.class || type == float.class) {
                return Types.REAL;
            } else if (Date.class.isAssignableFrom(type)) {
                return Types.DATE;
            }
            return Types.NULL;
        }
    }

    /**
     * Manual smoke test: lists all users, then deletes the user with id "1".
     * Requires the database configured in {@code resources/dbConfig.properties}.
     */
    public static void main(String[] args) throws Exception {
        DBSessionFactory sessionFactory = new DBSessionFactory();
        // a single session for the whole run, closed automatically
        // (the original opened two sessions and closed neither)
        try (DBSession session = sessionFactory.openSession()) {
            List<User> userList = session.list(User.class);
            System.out.println(userList);

            // sample entity for trying session.save(user) / session.update(user)
            User user = new User();
            user.setId("1");
            user.setUsername("usher");
            user.setPassword("1234567");
            user.setNickname("usher");
            user.setPhone("323232");

            System.out.println(session.delete(User.class, "1"));
        }
    }
}